From 2da00a6d9d206e420ca71decafd1ec46fd10d01b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 11:16:42 -0400 Subject: [PATCH 01/64] Remove legacy c++ standalone debug program --- debug/api_debug.cpp | 810 -------------------------------- debug/data/heterosked_test.csv | 101 ---- debug/data/heterosked_train.csv | 401 ---------------- 3 files changed, 1312 deletions(-) delete mode 100644 debug/api_debug.cpp delete mode 100644 debug/data/heterosked_test.csv delete mode 100644 debug/data/heterosked_train.csv diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp deleted file mode 100644 index 5426fe7e..00000000 --- a/debug/api_debug.cpp +++ /dev/null @@ -1,810 +0,0 @@ -/*! Copyright (c) 2024 stochtree authors*/ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace StochTree{ - -void GenerateDGP1(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, std::vector& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) { - // Data dimensions - n = 1000; - x_cols = 2; - omega_cols = 1; - y_cols = 1; - if (rfx_included) { - num_rfx_groups = 2; - rfx_basis_cols = 1; - } else { - num_rfx_groups = 0; - rfx_basis_cols = 0; - } - - // Resize data - covariates.resize(n * x_cols); - basis.resize(n * omega_cols); - rfx_basis.resize(n * rfx_basis_cols); - outcome.resize(n * y_cols); - rfx_groups.resize(n); - feature_types.resize(x_cols, FeatureType::kNumeric); - - // Random number generation - standard_normal normal_dist; - - // DGP parameters - std::vector betas{-10, -5, 5, 10}; - int num_partitions = betas.size(); - double f_x_omega; - double rfx; - double error; - - for (int i = 0; i < n; i++) { - for (int j = 0; j < x_cols; j++) { - 
covariates[i*x_cols + j] = standard_uniform_draw(gen); - } - - for (int j = 0; j < omega_cols; j++) { - basis[i*omega_cols + j] = standard_uniform_draw(gen); - } - - if (rfx_included) { - for (int j = 0; j < rfx_basis_cols; j++) { - rfx_basis[i * rfx_basis_cols + j] = 1; - } - - if (i % 2 == 0) { - rfx_groups[i] = 1; - } - else { - rfx_groups[i] = 2; - } - } - - for (int j = 0; j < y_cols; j++) { - if ((covariates[i * x_cols + 0] >= 0.0) && covariates[i * x_cols + 0] < 0.25) { - f_x_omega = betas[0] * basis[i * omega_cols + 0]; - } else if ((covariates[i * x_cols + 0] >= 0.25) && covariates[i * x_cols + 0] < 0.5) { - f_x_omega = betas[1] * basis[i * omega_cols + 0]; - } else if ((covariates[i * x_cols + 0] >= 0.5) && covariates[i * x_cols + 0] < 0.75) { - f_x_omega = betas[2] * basis[i * omega_cols + 0]; - } else { - f_x_omega = betas[3] * basis[i * omega_cols + 0]; - } - error = 0.1 * normal_dist(gen); - outcome[i * y_cols + j] = f_x_omega + error; - if (rfx_included) { - if (rfx_groups[i] == 1) { - rfx = 5.; - } - else { - rfx = -5.; - } - outcome[i * y_cols + j] += rfx; - } - } - } -} - -void int_to_binary_vector(int32_t input, std::vector& output, int32_t offset) { - bool terminated = false; - int numerator = input; - int vec_pos = 0; - if (numerator < 2) { - terminated = true; - output.at(offset + vec_pos) = numerator; - } - while (!terminated) { - std::div_t div_result = std::div(numerator, 2); - output.at(offset + vec_pos) = div_result.rem; - if (div_result.quot == 1) { - terminated = true; - output.at(offset + vec_pos + 1) = 1; - } else { - numerator = div_result.quot; - vec_pos += 1; - } - } -} - -void GenerateDGP2(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, std::vector& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) { - // Data dimensions - int n1 = 50; - x_cols = 
10; - int num_cells = std::pow(2, x_cols); - int p1 = 100; - if (p1 >= num_cells) Log::Fatal("p1 must be < 2^x_cols"); - n = n1*num_cells; - omega_cols = 0; - y_cols = 1; - if (rfx_included) { - num_rfx_groups = 2; - rfx_basis_cols = 1; - } - else { - num_rfx_groups = 0; - rfx_basis_cols = 0; - } - - // Resize data - covariates.resize(n * x_cols); - basis.resize(n * omega_cols); - rfx_basis.resize(n * rfx_basis_cols); - outcome.resize(n * y_cols); - rfx_groups.resize(n); - feature_types.resize(x_cols, FeatureType::kNumeric); - - // Random number generation - standard_normal normal_dist; - - // Generate a sequence of integers from 0 to num_cells - 1 - std::vector cell_run(num_cells); - std::iota(cell_run.begin(), cell_run.end(), 0); - - // Repeat this sequence n1 times as the "covariates" - std::vector cell_vector; - for (int i = 0; i < n1; i++) { - std::copy(cell_run.begin(), cell_run.end(), std::back_inserter(cell_vector)); - } - - // Convert cells to binary covariate columns (row-major) - std::vector covariates_binary(n * x_cols); - int32_t offset = 0; - for (size_t i = 0; i < n; i++) { - int_to_binary_vector(cell_vector.at(i), covariates_binary, offset); - offset += x_cols; - } - - // Add (folded) gaussian noise to the binary covariates - // std::vector covariates_numeric(n* x_cols); - std::vector noise1(n); - std::vector noise2(n); - int switch_flip; - for (size_t i = 0; i < n; i++) { - noise1.at(i) = std::abs(normal_dist(gen)); - noise2.at(i) = std::abs(normal_dist(gen)); - } - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < x_cols; j++) { - switch_flip = covariates_binary.at(i * x_cols + j); - covariates.at(i * x_cols + j) = switch_flip * noise1.at(i) + (1 - switch_flip) * noise2.at(i); - } - } - - // DGP parameters - double intercept = 0.5; - std::vector cell_coefficients_sparse(p1-1); - for (int i = 0; i < cell_coefficients_sparse.size(); i++) { - cell_coefficients_sparse.at(i) = -10*std::abs(normal_dist(gen)); - } - std::vector 
cell_weights(num_cells, 1./num_cells); - std::vector cell_indices_sparse(p1 - 1); - walker_vose cell_selector(cell_weights.begin(), cell_weights.end()); - for (int i = 0; i < p1-1; i++) { - cell_indices_sparse.at(i) = cell_selector(gen); - } - std::vector cell_coefficients_full(num_cells, 0.); - for (int i = 0; i < p1 - 1; i++) { - cell_coefficients_full.at(cell_indices_sparse.at(i)) = cell_coefficients_sparse.at(i); - } - double f_x; - double rfx; - double error; - - // Outcome - for (int i = 0; i < n; i++) { - f_x = intercept + cell_coefficients_full.at(cell_vector.at(i)); - - if (rfx_included) { - for (int j = 0; j < rfx_basis_cols; j++) { - rfx_basis[i * rfx_basis_cols + j] = 1; - } - - if (i % 2 == 0) { - rfx_groups[i] = 1; - } - else { - rfx_groups[i] = 2; - } - } - - for (int j = 0; j < y_cols; j++) { - error = 0.1 * normal_dist(gen); - outcome[i * y_cols + j] = f_x + error; - if (rfx_included) { - if (rfx_groups[i] == 1) { - rfx = 5.; - } - else { - rfx = -5.; - } - outcome[i * y_cols + j] += rfx; - } - } - } -} - -void GenerateDGP3(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, std::vector& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) { - // Data dimensions - n = 1000; - x_cols = 2; - omega_cols = 2; - y_cols = 1; - if (rfx_included) { - num_rfx_groups = 2; - rfx_basis_cols = 1; - } else { - num_rfx_groups = 0; - rfx_basis_cols = 0; - } - - // Resize data - covariates.resize(n * x_cols); - basis.resize(n * omega_cols); - rfx_basis.resize(n * rfx_basis_cols); - outcome.resize(n * y_cols); - rfx_groups.resize(n); - feature_types.resize(x_cols, FeatureType::kNumeric); - - // Random number generation - standard_normal normal_dist; - - // DGP parameters - std::vector betas{-10, -5, 5, 10}; - int num_partitions = betas.size(); - double f_x_omega; - double rfx; - double 
error; - - for (int i = 0; i < n; i++) { - for (int j = 0; j < x_cols; j++) { - covariates[i*x_cols + j] = standard_uniform_draw(gen); - } - - for (int j = 0; j < omega_cols; j++) { - basis[i*omega_cols + j] = standard_uniform_draw(gen); - } - - if (rfx_included) { - for (int j = 0; j < rfx_basis_cols; j++) { - rfx_basis[i * rfx_basis_cols + j] = 1; - } - - if (i % 2 == 0) { - rfx_groups[i] = 1; - } - else { - rfx_groups[i] = 2; - } - } - - for (int j = 0; j < y_cols; j++) { - if ((covariates[i * x_cols + 0] >= 0.0) && covariates[i * x_cols + 0] < 0.25) { - f_x_omega = betas[0] * basis[i * omega_cols + 0]; - } else if ((covariates[i * x_cols + 0] >= 0.25) && covariates[i * x_cols + 0] < 0.5) { - f_x_omega = betas[1] * basis[i * omega_cols + 0]; - } else if ((covariates[i * x_cols + 0] >= 0.5) && covariates[i * x_cols + 0] < 0.75) { - f_x_omega = betas[2] * basis[i * omega_cols + 0]; - } else { - f_x_omega = betas[3] * basis[i * omega_cols + 0]; - } - error = 0.1 * normal_dist(gen); - outcome[i * y_cols + j] = f_x_omega + error; - if (rfx_included) { - if (rfx_groups[i] == 1) { - rfx = 5.; - } - else { - rfx = -5.; - } - outcome[i * y_cols + j] += rfx; - } - } - } -} - -void GenerateDGP4(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, std::vector& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) { - // Data dimensions - n = 400; - x_cols = 10; - omega_cols = 0; - y_cols = 1; - if (rfx_included) { - num_rfx_groups = 2; - rfx_basis_cols = 1; - } else { - num_rfx_groups = 0; - rfx_basis_cols = 0; - } - - // Resize data - covariates.resize(n * x_cols); - basis.resize(n * omega_cols); - rfx_basis.resize(n * rfx_basis_cols); - outcome.resize(n * y_cols); - rfx_groups.resize(n); - feature_types.resize(x_cols, FeatureType::kNumeric); - - // Random number generation - standard_normal 
normal_dist; - - // DGP parameters - std::vector betas{0.5, 1, 2, 3}; - int num_partitions = betas.size(); - double s_x; - double rfx; - double error; - - for (int i = 0; i < n; i++) { - for (int j = 0; j < x_cols; j++) { - covariates[i*x_cols + j] = standard_uniform_draw(gen); - } - - for (int j = 0; j < omega_cols; j++) { - basis[i*omega_cols + j] = standard_uniform_draw(gen); - } - - if (rfx_included) { - for (int j = 0; j < rfx_basis_cols; j++) { - rfx_basis[i * rfx_basis_cols + j] = 1; - } - - if (i % 2 == 0) { - rfx_groups[i] = 1; - } - else { - rfx_groups[i] = 2; - } - } - - for (int j = 0; j < y_cols; j++) { - if ((covariates[i * x_cols + 0] >= 0.0) && covariates[i * x_cols + 0] < 0.25) { - s_x = betas[0]; - } else if ((covariates[i * x_cols + 0] >= 0.25) && covariates[i * x_cols + 0] < 0.5) { - s_x = betas[1]; - } else if ((covariates[i * x_cols + 0] >= 0.5) && covariates[i * x_cols + 0] < 0.75) { - s_x = betas[2]; - } else { - s_x = betas[3]; - } - error = s_x * normal_dist(gen); - outcome[i * y_cols + j] = error; - if (rfx_included) { - if (rfx_groups[i] == 1) { - rfx = 5.; - } - else { - rfx = -5.; - } - outcome[i * y_cols + j] += rfx; - } - } - } -} - -void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& outcome_scale) { - data_size_t n = residual.NumRows(); - double outcome_val = 0.0; - double outcome_sum = 0.0; - double outcome_sum_squares = 0.0; - double var_y = 0.0; - for (data_size_t i = 0; i < n; i++){ - outcome_val = residual.GetElement(i); - outcome_sum += outcome_val; - outcome_sum_squares += std::pow(outcome_val, 2.0); - } - var_y = outcome_sum_squares / static_cast(n) - std::pow(outcome_sum / static_cast(n), 2.0); - outcome_scale = std::sqrt(var_y); - outcome_offset = outcome_sum / static_cast(n); - double previous_residual; - for (data_size_t i = 0; i < n; i++){ - previous_residual = residual.GetElement(i); - residual.SetElement(i, (previous_residual - outcome_offset) / outcome_scale); - } -} - -void RunDebug(int 
dgp_num = 0, const ModelType model_type = kConstantLeafGaussian, - bool rfx_included = false, int num_gfr = 10, int num_mcmc = 100, int random_seed = -1, - std::string dataset_filename = "", int outcome_col = -1, std::string covariate_cols = "", - std::string basis_cols = "", int num_threads = -1) { - // Flag the data as row-major - bool row_major = true; - - // Determine whether we will generate data or read from file - bool data_from_file = false; - if (!dataset_filename.empty()) { - data_from_file = true; - } - - // Random number generation - std::mt19937 gen; - if (random_seed == -1) { - std::random_device rd; - std::mt19937 gen(rd()); - } - else { - std::mt19937 gen(random_seed); - } - - // Initialize dataset - ForestDataset dataset = ForestDataset(); - - // Initialize outcome - ColumnVector residual = ColumnVector(); - - // Empty data containers and dimensions (filled in by calling a specific DGP simulation function below) - int n; - int x_cols; - int omega_cols; - int y_cols; - int num_rfx_groups; - int rfx_basis_cols; - std::vector covariates_raw; - std::vector basis_raw; - std::vector rfx_basis_raw; - std::vector outcome_raw; - std::vector rfx_groups; - std::vector feature_types; - - // Check for DGP : ModelType compatibility - if ((model_type != kConstantLeafGaussian) && (dgp_num == 1)) { - Log::Fatal("dgp 2 is only compatible with a constant leaf model"); - } - - // Generate the data - int output_dimension; - bool is_leaf_constant; - if (!data_from_file) { - if (dgp_num == 0) { - GenerateDGP1(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); - dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); - dataset.AddBasis(basis_raw.data(), n, omega_cols, row_major); - output_dimension = 1; - is_leaf_constant = false; - } else if (dgp_num == 1) { - GenerateDGP2(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, 
rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); - dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); - output_dimension = 1; - is_leaf_constant = true; - } else if (dgp_num == 2) { - GenerateDGP3(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); - dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); - dataset.AddBasis(basis_raw.data(), n, omega_cols, row_major); - output_dimension = omega_cols; - is_leaf_constant = false; - } else if (dgp_num == 3) { - GenerateDGP4(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); - dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); - output_dimension = 1; - is_leaf_constant = true; - } else { - Log::Fatal("Invalid dgp_num"); - } - // Construct residual - residual = ColumnVector(outcome_raw.data(), n); - } else { - // Override RFX - rfx_included = false; - // Construct residual - residual = ColumnVector(dataset_filename, outcome_col); - y_cols = 0; - // Add covariates - dataset.AddCovariatesFromCSV(dataset_filename, covariate_cols); - n = dataset.NumObservations(); - x_cols = dataset.NumCovariates(); - feature_types.resize(x_cols, FeatureType::kNumeric); - if (!basis_cols.empty()) { - dataset.AddBasisFromCSV(dataset_filename, basis_cols); - output_dimension = dataset.NumBasis(); - is_leaf_constant = false; - omega_cols = dataset.NumBasis(); - } else { - output_dimension = 1; - is_leaf_constant = true; - omega_cols = 0; - } - } - - // Runtime check --- cannot have case / variance weights and be modeling heteroskedastic variance - if ((dgp_num == 3) && (dataset.HasVarWeights())) { - StochTree::Log::Fatal("Cannot provide variance / case weights when modeling 
heteroskedasticity with a forest"); - } - - // Center and scale the data - double outcome_offset; - double outcome_scale; - OutcomeOffsetScale(residual, outcome_offset, outcome_scale); - - // Prepare random effects sampling (if desired) - RandomEffectsDataset rfx_dataset; - std::vector rfx_init(n, 0); - RandomEffectsTracker rfx_tracker = RandomEffectsTracker(rfx_init); - MultivariateRegressionRandomEffectsModel rfx_model = MultivariateRegressionRandomEffectsModel(1, 1); - RandomEffectsContainer rfx_container; - LabelMapper label_mapper; - if (rfx_included) { - // Construct a random effects dataset - rfx_dataset = RandomEffectsDataset(); - rfx_dataset.AddBasis(rfx_basis_raw.data(), n, rfx_basis_cols, true); - rfx_dataset.AddGroupLabels(rfx_groups); - - // Construct random effects tracker / model / container - RandomEffectsTracker rfx_tracker = RandomEffectsTracker(rfx_groups); - MultivariateRegressionRandomEffectsModel rfx_model = MultivariateRegressionRandomEffectsModel(rfx_basis_cols, num_rfx_groups); - RandomEffectsContainer rfx_container = RandomEffectsContainer(rfx_basis_cols, num_rfx_groups); - LabelMapper label_mapper = LabelMapper(rfx_tracker.GetLabelMap()); - - // Set random effects model parameters - Eigen::VectorXd working_param_init(rfx_basis_cols); - Eigen::MatrixXd group_param_init(rfx_basis_cols, num_rfx_groups); - Eigen::MatrixXd working_param_cov_init(rfx_basis_cols, rfx_basis_cols); - Eigen::MatrixXd group_param_cov_init(rfx_basis_cols, rfx_basis_cols); - double variance_prior_shape = 1.; - double variance_prior_scale = 1.; - working_param_init << 1.; - group_param_init << 1., 1.; - working_param_cov_init << 1; - group_param_cov_init << 1; - rfx_model.SetWorkingParameter(working_param_init); - rfx_model.SetGroupParameters(group_param_init); - rfx_model.SetWorkingParameterCovariance(working_param_cov_init); - rfx_model.SetGroupParameterCovariance(group_param_cov_init); - rfx_model.SetVariancePriorShape(variance_prior_shape); - 
rfx_model.SetVariancePriorScale(variance_prior_scale); - } - - // Initialize an ensemble - int num_trees = 50; - bool forest_exponentiated; - if (model_type == kLogLinearVariance) { - forest_exponentiated = true; - } else { - forest_exponentiated = false; - } - // "Active" tree ensemble - TreeEnsemble active_forest = TreeEnsemble(num_trees, output_dimension, is_leaf_constant, forest_exponentiated); - // Stored forest samples - ForestContainer forest_samples = ForestContainer(num_trees, output_dimension, is_leaf_constant, forest_exponentiated); - - // Initialize a leaf model - double leaf_prior_mean = 0.; - double leaf_prior_scale = 1./num_trees; - - // Initialize forest sampling machinery - double alpha = 0.95; - double beta = 2.; - int min_samples_leaf = 1; - int max_depth = 10; - int cutpoint_grid_size = 100; - double a_rfx = 1.; - double b_rfx = 1.; - double a_leaf = 2.; - double b_leaf = 0.5; - double a_global = 0; - double b_global = 0; - double a_0 = 1.5; - double a_forest = num_trees / (a_0 * a_0) + 0.5; - double b_forest = num_trees / (a_0 * a_0); - - // Set leaf model parameters - double leaf_scale; - double leaf_scale_init = 1.; - Eigen::MatrixXd leaf_scale_matrix(omega_cols, omega_cols); - Eigen::MatrixXd leaf_scale_matrix_init(omega_cols, omega_cols); - if (omega_cols > 0) { - leaf_scale_matrix_init = Eigen::MatrixXd::Identity(omega_cols, omega_cols); - // leaf_scale_matrix_init << 1.0, 0.0, 0.0, 1.0; - leaf_scale_matrix = leaf_scale_matrix_init / num_trees; - } - - // Set global variance - double global_variance; - double global_variance_init = 1.0; - - // Set variable weights - double const_var_wt = static_cast(1. 
/ x_cols); - std::vector variable_weights(x_cols, const_var_wt); - - // Initialize tracker and tree prior - ForestTracker tracker = ForestTracker(dataset.GetCovariates(), feature_types, num_trees, n); - TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); - - // Initialize variance models - GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); - LeafNodeHomoskedasticVarianceModel leaf_var_model = LeafNodeHomoskedasticVarianceModel(); - - // Initialize storage for samples of variance - std::vector global_variance_samples{}; - std::vector leaf_variance_samples{}; - - // Initialize leaf model - double init_val; - double init_val_glob; - std::vector init_vec; - if (model_type == kConstantLeafGaussian) { - init_val_glob = ComputeMeanOutcome(residual); - init_val = init_val_glob / static_cast(num_trees); - active_forest.SetLeafValue(init_val); - UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, false, std::minus()); - tracker.UpdatePredictions(&active_forest, dataset); - } else if (model_type == kUnivariateRegressionLeafGaussian) { - init_val_glob = ComputeMeanOutcome(residual); - init_val = init_val_glob / static_cast(num_trees); - active_forest.SetLeafValue(init_val); - UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, true, std::minus()); - tracker.UpdatePredictions(&active_forest, dataset); - } else if (model_type == kMultivariateRegressionLeafGaussian) { - init_val_glob = ComputeMeanOutcome(residual); - init_val = init_val_glob / static_cast(num_trees); - init_vec = std::vector(omega_cols, init_val); - active_forest.SetLeafVector(init_vec); - UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, true, std::minus()); - tracker.UpdatePredictions(&active_forest, dataset); - } else if (model_type == kLogLinearVariance) { - init_val_glob = ComputeVarianceOutcome(residual) * 0.4; - init_val = std::log(init_val_glob) / static_cast(num_trees); - 
active_forest.SetLeafValue(init_val); - tracker.UpdatePredictions(&active_forest, dataset); - std::vector initial_preds(n, init_val_glob); - dataset.AddVarianceWeights(initial_preds.data(), n); - } - - // Prepare the samplers - LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); - int num_features_subsample = x_cols; - - // Initialize vector of sweep update indices - std::vector sweep_indices(num_trees); - std::iota(sweep_indices.begin(), sweep_indices.end(), 0); - - // Run the GFR sampler - if (num_gfr > 0) { - for (int i = 0; i < num_gfr; i++) { - if (i == 0) { - global_variance = global_variance_init; - leaf_scale = leaf_scale_init; - } - else { - global_variance = global_variance_samples[i - 1]; - leaf_scale = leaf_variance_samples[i - 1]; - } - - // Sample tree ensemble - if (model_type == ModelType::kConstantLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample, num_threads); - } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample, num_threads); - } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample, num_threads, omega_cols); - } else if (model_type == ModelType::kLogLinearVariance) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, 
tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false, num_features_subsample, num_threads); - } - - if (rfx_included) { - // Sample random effects - rfx_model.SampleRandomEffects(rfx_dataset, residual, rfx_tracker, global_variance, gen); - rfx_container.AddSample(rfx_model); - } - - // Sample leaf node variance - leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen)); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); - } - } - - // Run the MCMC sampler - if (num_mcmc > 0) { - for (int i = num_gfr; i < num_gfr + num_mcmc; i++) { - if (i == 0) { - global_variance = global_variance_init; - leaf_scale = leaf_scale_init; - } - else { - global_variance = global_variance_samples[i - 1]; - leaf_scale = leaf_variance_samples[i - 1]; - } - - // Sample tree ensemble - if (model_type == ModelType::kConstantLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, num_threads); - } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, num_threads); - } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, num_threads, omega_cols); - } else if (model_type == ModelType::kLogLinearVariance) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, 
variable_weights, sweep_indices, global_variance, true, true, false, num_threads); - } - - if (rfx_included) { - // Sample random effects - rfx_model.SampleRandomEffects(rfx_dataset, residual, rfx_tracker, global_variance, gen); - rfx_container.AddSample(rfx_model); - } - - // Sample leaf node variance - leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen)); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); - } - } - - // Predict from the tree ensemble - int num_samples = num_gfr + num_mcmc; - std::vector pred_orig = forest_samples.Predict(dataset); - - if (rfx_included) { - // Predict from the random effects dataset - std::vector rfx_predictions(n * num_samples); - rfx_container.Predict(rfx_dataset, label_mapper, rfx_predictions); - } - - // Write model to a file - std::string filename = "model.json"; - forest_samples.SaveToJsonFile(filename); - - // Read and parse json from file - ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant); - forest_samples_parsed.LoadFromJsonFile(filename); - - // Make sure we can predict from both the original (above) and parsed forest containers - std::vector pred_parsed = forest_samples_parsed.Predict(dataset); -} - -} // namespace StochTree - -int main(int argc, char* argv[]) { - // Unpack command line arguments - int dgp_num = std::stoi(argv[1]); - if ((dgp_num != 0) && (dgp_num != 1) && (dgp_num != 2) && (dgp_num != 3)) { - StochTree::Log::Fatal("The first command line argument must be 0, 1, 2, or 3"); - } - int model_type_int = static_cast(std::stoi(argv[2])); - if ((model_type_int != 0) && (model_type_int != 1) && (model_type_int != 2) && (model_type_int != 3)) { - StochTree::Log::Fatal("The second command line argument must be 0, 1, 2, or 3"); - } - StochTree::ModelType model_type = static_cast(model_type_int); - int 
rfx_int = std::stoi(argv[3]); - if ((rfx_int != 0) && (rfx_int != 1)) { - StochTree::Log::Fatal("The third command line argument must be 0 or 1"); - } - bool rfx_included = static_cast(rfx_int); - int num_gfr = std::stoi(argv[4]); - if (num_gfr < 0) { - StochTree::Log::Fatal("The fourth command line argument must be >= 0"); - } - int num_mcmc = std::stoi(argv[5]); - if (num_mcmc < 0) { - StochTree::Log::Fatal("The fifth command line argument must be >= 0"); - } - int random_seed = std::stoi(argv[6]); - if (random_seed < -1) { - StochTree::Log::Fatal("The sixth command line argument must be >= -0"); - } - std::string dataset_filename = argv[7]; - int outcome_col = std::stoi(argv[8]); - std::string covariate_cols = argv[9]; - std::string basis_cols = argv[10]; - int num_threads = std::stoi(argv[11]); - - // Run the debug program - StochTree::RunDebug(dgp_num, model_type, rfx_included, num_gfr, num_mcmc, random_seed, - dataset_filename, outcome_col, covariate_cols, basis_cols, num_threads); -} diff --git a/debug/data/heterosked_test.csv b/debug/data/heterosked_test.csv deleted file mode 100644 index 9c20150b..00000000 --- a/debug/data/heterosked_test.csv +++ /dev/null @@ -1,101 +0,0 @@ -"y","X1","X2","X3","X4","X5","X6","X7","X8","X9","X10","f_x","s_x" -0.16602939170279,0.211290733190253,0.3512832315173,0.484317823080346,0.0410519035067409,0.544329481897876,0.208616065792739,0.715793566778302,0.519131515640765,0.804753824602813,0.242159801069647,NA,0.5 --0.492151343206358,0.0310312763322145,0.467259066179395,0.217408320866525,0.585395927308127,0.310897014569491,0.555933631025255,0.999427518807352,0.609017762588337,0.833850894821808,0.831058437004685,NA,0.5 -1.20354672980893,0.324771000305191,0.243206361541525,0.00510925496928394,0.225724841468036,0.818582467501983,0.428320674458519,0.52159065939486,0.743639749241993,0.846550665097311,0.274055331945419,NA,1 
--0.298743475101736,0.059549157274887,0.788293640129268,0.664353932486847,0.174846360227093,0.396595876198262,0.0166858048178256,0.671366059686989,0.164389409590513,0.748013024684042,0.213287916500121,NA,0.5 -1.03648471667601,0.252165184123442,0.103475574404001,0.159992411732674,0.836252088425681,0.74254785827361,0.88382129278034,0.559503264725208,0.0441013514064252,0.333072878420353,0.317699365783483,NA,1 -0.372412663856177,0.277434553951025,0.696199390338734,0.228776063071564,0.719009180320427,0.411765250377357,0.557771438034251,0.434479911345989,0.98785528820008,0.964887821115553,0.372693746583536,NA,1 -0.678007080201085,0.901923469034955,0.656224591424689,0.984398937784135,0.372089076321572,0.862292707199231,0.308036184636876,0.0231759585440159,0.113736179191619,0.934551725862548,0.281830958090723,NA,3 --0.552663625102306,0.384898619260639,0.653318807482719,0.57212492544204,0.617288585519418,0.212964130798355,0.596753151156008,0.907653672154993,0.364099610829726,0.8535895885434,0.874561849748716,NA,1 --0.774647129926176,0.542285286355764,0.891476916149259,0.86847648001276,0.3822966220323,0.468129127752036,0.658936342457309,0.679395364364609,0.313021682668477,0.295605404768139,0.946900382405147,NA,2 --1.44840009823072,0.786929592490196,0.990138707915321,0.00964757986366749,0.11595363356173,0.0707218167372048,0.0610561305657029,0.741362863453105,0.26028509857133,0.727263244334608,0.252514223800972,NA,3 --0.000520860259289017,0.373750334838405,0.77888114657253,0.958520723273978,0.851794941583648,0.405048871645704,0.789868315216154,0.650866032578051,0.848741357214749,0.00563913723453879,0.714518729131669,NA,1 -4.91065105854111,0.578438652446494,0.44657018291764,0.96312881493941,0.472523927222937,0.0506300190463662,0.498262552311644,0.523629083996639,0.353911799145862,0.00572359748184681,0.927425773814321,NA,2 
-0.246599975172457,0.323262227233499,0.280061168596148,0.221339529380202,0.189832986332476,0.364617298590019,0.677225903840736,0.126530721783638,0.237819334724918,0.558171211974695,0.957330684643239,NA,1 -0.216574141744894,0.128740054322407,0.535212400136515,0.49389783013612,0.27021420141682,0.782948551233858,0.722482954151928,0.256591950077564,0.287929816404358,0.411982560995966,0.477760225534439,NA,0.5 -3.02695240909448,0.830361526459455,0.771890457021073,0.678807897027582,0.249619299778715,0.304002076154575,0.733004137175158,0.739543926436454,0.798186061438173,0.900906472001225,0.851402544882149,NA,3 --0.321787256332478,0.630839701509103,0.987936206627637,0.945720319403335,0.0623901186045259,0.258995248703286,0.429870260413736,0.10198374139145,0.536068871850148,0.0391868108417839,0.294132520444691,NA,2 -0.212045418087806,0.0110131918918341,0.85342196887359,0.588755084900185,0.461576499743387,0.770185991190374,0.871162265306339,0.0518779626581818,0.869023777078837,0.722436351701617,0.799294896423817,NA,0.5 --1.21242344678739,0.425265174359083,0.357040172442794,0.340407437179238,0.488938638009131,0.354673355352134,0.615913023473695,0.24169555818662,0.993121116189286,0.728206684347242,0.014574522851035,NA,1 -1.28352116284694,0.657284670509398,0.787876920541748,0.395542721729726,0.422870381968096,0.305611111223698,0.258184634381905,0.99757966119796,0.214673584559932,0.472818041918799,0.445311633404344,NA,2 --0.574621916007152,0.938520508352667,0.00738264597021043,0.339304566383362,0.935923039447516,0.826553477905691,0.751117973355576,0.412732847733423,0.446600724477321,0.908122762804851,0.410085891839117,NA,3 --1.79033131705676,0.980508934007958,0.991732611320913,0.0741828014142811,0.45592138543725,0.0406530441250652,0.153601433616132,0.680239439709112,0.495365400798619,0.136085328413174,0.0960345070343465,NA,3 
-1.2730466507008,0.452979361405596,0.353644639486447,0.955131257884204,0.253656769637018,0.531206650892273,0.689106163335964,0.535432658391073,0.512008368736133,0.523187090177089,0.483651001704857,NA,1 -0.509242586001192,0.0175574035383761,0.875898850383237,0.553155447356403,0.245557092595845,0.353349522221833,0.957418726291507,0.397177098784596,0.216923247324303,0.675873225554824,0.0169080917257816,NA,0.5 --0.136460417178799,0.152752772672102,0.813950020587072,0.820145434001461,0.618891098769382,0.972209425177425,0.535218630451709,0.251730428542942,0.206172946607694,0.795590546447784,0.940857806708664,NA,0.5 -0.485453387886637,0.402201347984374,0.646702455822378,0.656405615853146,0.228795171016827,0.628267531748861,0.0990123685915023,0.514316890388727,0.170408786740154,0.966783886775374,0.67154831183143,NA,1 -0.358253559632074,0.103477319702506,0.780801502522081,0.291534556541592,0.102512540761381,0.0295835479628295,0.529529393883422,0.729500360088423,0.752762402174994,0.442406627582386,0.0135988409165293,NA,0.5 --5.08545632040851,0.502724350197241,0.428052270086482,0.527379664825276,0.26013516378589,0.645728018833324,0.877794990781695,0.123491464182734,0.907502186950296,0.479111275169998,0.0677777659147978,NA,2 -0.707859361399802,0.197307132417336,0.620997583027929,0.118650285294279,0.991431219037622,0.912463267799467,0.897971634287387,0.922440506285056,0.905837482772768,0.0830823809374124,0.424855173798278,NA,0.5 --0.288322601219351,0.0948237683624029,0.572173510212451,0.900335347745568,0.536399732576683,0.609837994910777,0.291524545988068,0.537141105625778,0.524107180535793,0.156775680137798,0.559588886797428,NA,0.5 -0.549673124655909,0.0728990600910038,0.447939761448652,0.282676422968507,0.89773990213871,0.0633204986806959,0.865510883042589,0.105334107298404,0.133278978988528,0.190158660756424,0.180829501943663,NA,0.5 
-0.540571216994743,0.237544966395944,0.707299687666818,0.419169025262818,0.743102890206501,0.269150409614667,0.262001751922071,0.799080436583608,0.2612558118999,0.273263912182301,0.450734896818176,NA,0.5 --0.0452946833490895,0.226565732620656,0.955050441902131,0.150834861211479,0.852100720163435,0.115236598765478,0.736819905461743,0.699184895725921,0.0959788486361504,0.278612334979698,0.56139239971526,NA,0.5 -0.803100296173604,0.530489934142679,0.194167217472568,0.529533125227317,0.713073023594916,0.926254247082397,0.580682525411248,0.297447438817471,0.769652892835438,0.668208419112489,0.192935517756268,NA,2 -4.75098156092376,0.510896486230195,0.674823730485514,0.756135024828836,0.539776347577572,0.319426174042746,0.557334935991094,0.713212842354551,0.714800932211801,0.606322152074426,0.17348047113046,NA,2 -1.14191042699979,0.704268886707723,0.301489194622263,0.67195270746015,0.743559376103804,0.791731091681868,0.890636668074876,0.923103271983564,0.125805326737463,0.261845856206492,0.053273844299838,NA,2 --1.96289199262713,0.760442031314597,0.742438767105341,0.355483126826584,0.519423028221354,0.103875080589205,0.619771863799542,0.889494609320536,0.9085327538196,0.0071903329808265,0.535958426771685,NA,3 --1.03924818368133,0.355457389494404,0.383036406245083,0.142556034494191,0.735553824109957,0.801241792505607,0.312697742134333,0.730286944191903,0.337912452174351,0.0597852817736566,0.109000069089234,NA,1 -2.56288078808684,0.640810209326446,0.781058971770108,0.45493664429523,0.475713788997382,0.998631661524996,0.845030306605622,0.513773418031633,0.416949922917411,0.517727062571794,0.653628408210352,NA,2 -0.512191613003144,0.841465181438252,0.340397675754502,0.183199171675369,0.348972214618698,0.399342444725335,0.673026149626821,0.462349094683304,0.73421816341579,0.642916572280228,0.501591531094164,NA,3 
--0.390539543957083,0.670964063378051,0.323868047911674,0.873477897839621,0.235217976849526,0.678807153599337,0.459241795353591,0.494968741200864,0.0186729081906378,0.205307064112276,0.93625286151655,NA,2 --0.387138910565391,0.505246647400782,0.206667850259691,0.212133591994643,0.223405942320824,0.534676593495533,0.622734377160668,0.088978860527277,0.732178951846436,0.555771772982553,0.74612275743857,NA,2 --0.589758275320359,0.917571729980409,0.606019907863811,0.361401373520494,0.349297178443521,0.815126905683428,0.432852742262185,0.217298388015479,0.659059998346493,0.282440681708977,0.74272277392447,NA,3 --0.535819616401154,0.195396314607933,0.922371664317325,0.635648620780557,0.0705672230105847,0.389690730255097,0.417000237386674,0.0651042077224702,0.747380811953917,0.306937695480883,0.186355540528893,NA,0.5 -0.100718265891839,0.12325372453779,0.838415022008121,0.486661186208948,0.464539987500757,0.0181915941648185,0.543029651045799,0.084053430473432,0.314679026603699,0.367974560242146,0.29371877363883,NA,0.5 -4.4639440117062,0.943375065457076,0.92623546323739,0.376004029763862,0.170030419481918,0.136762315873057,0.569733060197905,0.480285328347236,0.851396167185158,0.530285589862615,0.0130437007173896,NA,3 --1.7739095259091,0.750196340726689,0.386715405387804,0.119214843027294,0.157079570228234,0.972443665377796,0.0673526108730584,0.901384715922177,0.521999530028552,0.226870314916596,0.699825004441664,NA,3 --3.30597720310619,0.631283310241997,0.137164937565103,0.283018422778696,0.0198266028892249,0.599854060914367,0.308085889788345,0.868349778000265,0.512785997241735,0.855990704614669,0.914151601027697,NA,2 -0.573932596174342,0.0070427656173706,0.641306987032294,0.176330357091501,0.567426423309371,0.860410574125126,0.408202450722456,0.201708807609975,0.791534950956702,0.511292593320832,0.750208221375942,NA,0.5 
--2.21001462414232,0.557995214127004,0.356393828522414,0.992580118589103,0.0819355831481516,0.553154400084168,0.32692315033637,0.239232912426814,0.190619267290458,0.00340557587333024,0.911083268001676,NA,2 --2.29486992099795,0.626155448844656,0.925242279656231,0.2268055845052,0.104973732959479,0.319284403230995,0.0643279056530446,0.407658671727404,0.308903530938551,0.547128311824054,0.589992973487824,NA,2 -1.20099834449345,0.465869394131005,0.820835534017533,0.980574035085738,0.109229437308386,0.939030001172796,0.209905731026083,0.141733445925638,0.5528343396727,0.369989494327456,0.328800423070788,NA,1 -2.04880751638206,0.705077335238457,0.499480672879145,0.795235471567139,0.749453478259966,0.910042788600549,0.965336692053825,0.16686162375845,0.790705517167225,0.60525494068861,0.386548299575225,NA,2 --0.510808721017406,0.0382022578269243,0.0574023951776326,0.298518662340939,0.847815460059792,0.419254497392103,0.545455661835149,0.804183976259083,0.566720734350383,0.938392051728442,0.968750852160156,NA,0.5 -0.678353506101917,0.00150989321991801,0.282608818961307,0.227338050492108,0.629864166723564,0.0941448376979679,0.269607466645539,0.0547637669369578,0.734315895475447,0.624765485292301,0.949363355059177,NA,0.5 --3.84612254244577,0.817981925327331,0.499992654426023,0.763754336163402,0.459059632383287,0.466768078505993,0.0136354095302522,0.119035738054663,0.619225281523541,0.695366266882047,0.287478444399312,NA,3 -1.18831966729165,0.689650556538254,0.954330617329106,0.0839970302768052,0.226895423140377,0.289217309327796,0.483793317107484,0.826940205646679,0.489855382591486,0.196166640147567,0.896134687354788,NA,2 --0.984385859730434,0.221693887142465,0.386208922602236,0.671697184676304,0.865381807554513,0.56949229654856,0.96111100865528,0.867747253272682,0.981907822191715,0.349981465376914,0.436939367791638,NA,0.5 
--0.395896870249781,0.995180341880769,0.518618949921802,0.838818890042603,0.500900116749108,0.120013452135026,0.757572711678222,0.317005023127422,0.970106801018119,0.118986214278266,0.764332336606458,NA,3 -1.38921159150761,0.674878076650202,0.494008522713557,0.11772581236437,0.903489429038018,0.538562575588003,0.725335141411051,0.373649602057412,0.479591501876712,0.292627403745428,0.750551311299205,NA,2 -2.42945595659824,0.953342667547986,0.083741452312097,0.756313804769889,0.162194768199697,0.247160838916898,0.287883190670982,0.55613048421219,0.396303126122802,0.824216928100213,0.150460966397077,NA,3 --3.6886999739904,0.687809974187985,0.753712829900905,0.101322782225907,0.698061282979324,0.231127486098558,0.223462254740298,0.0152061164844781,0.74869236536324,0.895096591673791,0.904631495475769,NA,2 -0.536585118223136,0.688023552997038,0.401733021717519,0.964100113371387,0.407351006288081,0.605407851049677,0.754583293804899,0.348921799799427,0.613906478276476,0.991704752435908,0.878143492620438,NA,2 --0.945724263155543,0.912634621141478,0.0984702135901898,0.383788489736617,0.127180437557399,0.202219387050718,0.871567155467346,0.128467405447736,0.565469443099573,0.198427167953923,0.211701851338148,NA,3 -1.88190395261656,0.982580098556355,0.440120175015181,0.976914363214746,0.384266132023185,0.651897041592747,0.424716322217137,0.656596934422851,0.675765822641551,0.502742925658822,0.0215359264984727,NA,3 -1.75984850787277,0.818847576156259,0.283998677041382,0.752549906261265,0.436714354204014,0.99236780917272,0.125404683407396,0.771643821382895,0.72715258016251,0.0586418081074953,0.712552866665646,NA,3 -0.840920362678104,0.147506986744702,0.812251603929326,0.144564733840525,0.968569115269929,0.245198084274307,0.901848326670006,0.778830538270995,0.769952379865572,0.63684784504585,0.340954556828365,NA,0.5 
--1.9371212817847,0.811514051631093,0.0480760163627565,0.165090465452522,0.285216748947278,0.520478271646425,0.14143471699208,0.530164128635079,0.171509642153978,0.663002614164725,0.605542375706136,NA,3 -2.66745669003816,0.618318668100983,0.538341464707628,0.279087009374052,0.891984875546768,0.127626759931445,0.712945127626881,0.0132724172435701,0.173096344573423,0.185944046359509,0.588332365499809,NA,2 -0.619195277945595,0.861147953663021,0.00769214262254536,0.243746276246384,0.0888408594764769,0.991699742851779,0.997524467529729,0.361437901854515,0.604364247061312,0.258257536683232,0.485892802011222,NA,3 --1.20727540221404,0.716689873952419,0.612233497900888,0.836100687040016,0.588582671945915,0.491994495503604,0.792359412880614,0.0365804079920053,0.800699540879577,0.275558543158695,0.215889111161232,NA,2 -4.72985785755328,0.932213481282815,0.372599190333858,0.526720034424216,0.653838029364124,0.497327620862052,0.48135881498456,0.0302206429187208,0.234266108367592,0.33457084489055,0.598570736125112,NA,3 -0.288958968840092,0.0119730825535953,0.300384242553264,0.850864094216377,0.893871575593948,0.680723139783368,0.990262570325285,0.865412194048986,0.379822072573006,0.718840815359727,0.890456868568435,NA,0.5 -0.259681068955783,0.237852266291156,0.0348299730103463,0.718882874818519,0.653284071711823,0.112274453276768,0.0580400477629155,0.833725305041298,0.536504363641143,0.720546880736947,0.159137749811634,NA,0.5 -0.347321451075773,0.188694522483274,0.531638014828786,0.572584339883178,0.238210330018774,0.231171588879079,0.413350867340341,0.0659690995234996,0.564861474558711,0.89177248836495,0.0945249791257083,NA,0.5 -3.89576276409494,0.833090277388692,0.203691257163882,0.146478878101334,0.819324584677815,0.196580097544938,0.892589659895748,0.226715816417709,0.877559039741755,0.85475739534013,0.0848844614811242,NA,3 
-0.355581450144155,0.431055790279061,0.416096288943663,0.0859726935159415,0.325928226578981,0.58786572702229,0.350675149122253,0.401109970640391,0.756839712616056,0.0408970382995903,0.338316329289228,NA,1 -1.60495074110366,0.621179772540927,0.686551147606224,0.273364662658423,0.874324111267924,0.725664374418557,0.859822448343039,0.737721871584654,0.271499434486032,0.237752126529813,0.236145419767126,NA,2 --2.6328605154421,0.692925159586594,0.32423769752495,0.0426305204164237,0.968091825256124,0.1722840978764,0.862542052520439,0.704603992169723,0.0835464387200773,0.100951640866697,0.527263561496511,NA,2 -0.0799409275378214,0.496741143520921,0.358587516937405,0.925880230497569,0.544319620588794,0.682037363760173,0.974792276741937,0.415831092977896,0.969078256050125,0.604709276696667,0.89061022317037,NA,1 -2.72669228705894,0.478368178475648,0.310989870224148,0.928430530708283,0.247262259013951,0.795121156377718,0.861811829963699,0.400008339434862,0.144553888821974,0.947825188748538,0.19086144422181,NA,1 -0.0790671158723385,0.127485178643838,0.703777753748,0.503316113026813,0.764946941519156,0.433781018713489,0.384083716897294,0.968963841674849,0.380835567833856,0.876958164153621,0.818229214288294,NA,0.5 --0.459476378753931,0.18172625801526,0.720858050510287,0.848017353564501,0.720434938790277,0.728625974617898,0.44146662694402,0.951146713923663,0.455086235888302,0.740384874632582,0.091537274653092,NA,0.5 --0.00279715401469337,0.887286564800888,0.161548246862367,0.0689359360840172,0.555442054290324,0.846620650961995,0.00943395867943764,0.108124481979758,0.89916443801485,0.677391635021195,0.0622330722399056,NA,3 -0.249461137396448,0.0212012422271073,0.359011579537764,0.719863031757995,0.249324724078178,0.707923346664757,0.122814457165077,0.939011327456683,0.158998906379566,0.995269876671955,0.441887193359435,NA,0.5 
-0.0133950863816936,0.0725361246149987,0.0616531122941524,0.287096783984452,0.382249645655975,0.883280041394755,0.615435254527256,0.487508539808914,0.597117075929418,0.629250884056091,0.159072700887918,NA,0.5 -0.418448230677183,0.0969394627027214,0.411688929889351,0.449063409119844,0.488310252549127,0.259443101240322,0.987585310125723,0.441142725525424,0.242221797816455,0.250788792734966,0.619807209586725,NA,0.5 -1.16506424422402,0.786211173050106,0.369333798764274,0.840181995881721,0.386459692381322,0.285024101380259,0.641522094374523,0.849066619062796,0.0109615188557655,0.2850232264027,0.217773179057986,NA,3 -0.989575894880979,0.539474389981478,0.796261472860351,0.47632975759916,0.676875605247915,0.890021840808913,0.499118740204722,0.972148579079658,0.982196035329252,0.0354776941239834,0.547611665446311,NA,2 --5.41773583711259,0.991268366575241,0.841971538960934,0.831874355440959,0.383187236264348,0.318007141817361,0.644914947915822,0.261283464496955,0.146094920346513,0.701359197963029,0.458953364752233,NA,3 -1.14091061117272,0.381695970892906,0.186867855954915,0.592309525003657,0.0125671157147735,0.62503357976675,0.712586513254791,0.063052834244445,0.450469782575965,0.306621422059834,0.900312934769318,NA,1 --0.0430420048530519,0.126209627138451,0.0910242712125182,0.877147340215743,0.65187450312078,0.148876765510067,0.271347487112507,0.483428512467071,0.995978597551584,0.910394222242758,0.631069793598726,NA,0.5 -2.53541834119888,0.527330596232787,0.949522181181237,0.366707332897931,0.0310849468223751,0.659584330860525,0.315074123907834,0.177570374915376,0.579150864621624,0.698416418628767,0.336142445914447,NA,2 --0.836400788818656,0.152910969685763,0.588620311347768,0.183476264355704,0.471249244175851,0.456099348841235,0.358562543056905,0.0993230841122568,0.865682073868811,0.688885818701237,0.783728409092873,NA,0.5 
--0.36119486883805,0.255377688677981,0.472820749040693,0.497897547669709,0.599874146282673,0.595044590299949,0.915363078936934,0.366866510361433,0.132163925329223,0.286865012487397,0.395963127491996,NA,1 -3.04844600215742,0.784051381517202,0.334584510419518,0.889806121354923,0.633161137346178,0.861800517654046,0.5874773489777,0.395519501296803,0.458229560870677,0.794260403839871,0.584855122258887,NA,3 --1.56680469940733,0.424370453692973,0.798977045807987,0.652883867966011,0.178207511780784,0.704748782562092,0.0875129597261548,0.638247390044853,0.396533044287935,0.528938032919541,0.909660720732063,NA,1 --0.125657442776976,0.367308006854728,0.554850550601259,0.471428145188838,0.812708863057196,0.960221603047103,0.897328518796712,0.436890807002783,0.851095250342041,0.439047307474539,0.350051162531599,NA,1 --0.081334269056118,0.763460468733683,0.334580871742219,0.455008901190013,0.758000182220712,0.850450441241264,0.255375979002565,0.451483712298796,0.409017519326881,0.0273381865117699,0.59320038347505,NA,3 -0.492823092528766,0.0655732385348529,0.881846528034657,0.0530098504386842,0.35191644448787,0.175170796457678,0.296730656642467,0.907911094138399,0.603398642269894,0.255270632915199,0.642264471622184,NA,0.5 --0.000915148320753323,0.613401755923405,0.360998818883672,0.247979114530608,0.784389181062579,0.354104121448472,0.209656548919156,0.0889754833187908,0.849633313016966,0.397846080595627,0.135037919739261,NA,2 diff --git a/debug/data/heterosked_train.csv b/debug/data/heterosked_train.csv deleted file mode 100644 index 92d47304..00000000 --- a/debug/data/heterosked_train.csv +++ /dev/null @@ -1,401 +0,0 @@ -"y","X1","X2","X3","X4","X5","X6","X7","X8","X9","X10","f_x","s_x" --0.282241503700301,0.0321366493590176,0.0392713348846883,0.491026453208178,0.829869156936184,0.419834628002718,0.674110037973151,0.440613425569609,0.00938279670663178,0.688462630612776,0.246391290798783,0,0.5 
-0.560156335491632,0.633398831356317,0.713975282153115,0.0746546245645732,0.183240191545337,0.366673885611817,0.195146100129932,0.543941563926637,0.770088015124202,0.815632452024147,0.204938031733036,NA,2 --2.55310551844943,0.699684502789751,0.518513803370297,0.278569060144946,0.518335236469284,0.557218188419938,0.487766429549083,0.912966759642586,0.306013585766777,0.737159843789414,0.415715463226661,NA,2 -1.2493544072818,0.376545963343233,0.365289261564612,0.878427121788263,0.236202105879784,0.150846817996353,0.263488188385963,0.295650234911591,0.752220941707492,0.503766167676076,0.0375082376413047,NA,1 --0.350237222512917,0.685897701652721,0.439451805083081,0.346272532362491,0.639235025504604,0.0299059110693634,0.791390363825485,0.818463942268863,0.805753854336217,0.0685151207726449,0.653748897369951,NA,2 -0.875396882302278,0.30443329596892,0.593776761321351,0.252662484999746,0.143341688672081,0.202486322494224,0.209879243280739,0.0234908810816705,0.443319902056828,0.865865753497928,0.605711343931034,NA,1 -0.486947118214128,0.524243678199127,0.0881543764844537,0.894031626405194,0.117429533274844,0.153696053894237,0.0199424722231925,0.0166082351934165,0.0904451431706548,0.634674217319116,0.25031628459692,NA,2 -1.52973577099548,0.478108013281599,0.41259130765684,0.922683839220554,0.537282897857949,0.28786034998484,0.518826235085726,0.768439579289407,0.712297242833301,0.471964940661564,0.0293229683302343,NA,1 -0.314146031707177,0.352395879570395,0.994232393568382,0.36731672892347,0.371404710225761,0.226430130191147,0.988522542640567,0.171963014639914,0.590013455832377,0.61276261578314,0.0982697720173746,NA,1 -1.43575656817199,0.343638661550358,0.0898715713992715,0.671082190936431,0.0814627697691321,0.115701446775347,0.404091404750943,0.914198226062581,0.77044213260524,0.724840591661632,0.960911155678332,NA,1 
-0.326893619264727,0.748311516828835,0.972395642194897,0.462105612503365,0.849746079416946,0.525140055455267,0.691481846617535,0.64259241358377,0.383345428388566,0.193897266406566,0.471145344432443,NA,2 -5.33184058038735,0.805053810123354,0.694419888081029,0.688123103696853,0.857482564169914,0.412985544884577,0.756523177726194,0.587173974374309,0.698520486243069,0.306543357903138,0.569155315635726,NA,3 --0.080953281849018,0.292309782933444,0.886532769538462,0.590175394201651,0.763742916053161,0.843529825564474,0.444129067705944,0.506667435867712,0.263209822122008,0.25964065361768,0.51411981228739,NA,1 --0.0642885588731738,0.0813713425304741,0.733218841487542,0.645473898854107,0.48057269025594,0.958589600399137,0.70577060454525,0.746377976378426,0.228743602987379,0.102548423456028,0.184051728574559,NA,0.5 --2.52172005222636,0.465416571125388,0.113867231411859,0.423410250106826,0.912441357970238,0.253497969126329,0.957222544355318,0.298387421993539,0.55475458712317,0.320897468365729,0.92878125119023,NA,1 -0.26999574319132,0.0482504856772721,0.870362227084115,0.0656226673163474,0.352718068053946,0.950954237952828,0.0427663850132376,0.726099342806265,0.917875221231952,0.958915420575067,0.913584525929764,NA,0.5 --3.83785409353774,0.645924511365592,0.445315549150109,0.396382685052231,0.90340900560841,0.380887056933716,0.26203626464121,0.901702417759225,0.417292382800952,0.660986073780805,0.41848081164062,NA,2 -1.07793425501999,0.529066368937492,0.877272051293403,0.683232536539435,0.744753323029727,0.3394386311993,0.526980829657987,0.636599146993831,0.445450452622026,0.164547258988023,0.882723631570116,NA,2 -0.62946218194112,0.621317676734179,0.144061228726059,0.754505411023274,0.52779389359057,0.0482079237699509,0.308562529738992,0.983435485279188,0.138843381311744,0.889652899932116,0.731016719480976,NA,2 
-0.149334178305908,0.687091745669022,0.0368512177374214,0.239253385690972,0.0125792836770415,0.357229135232046,0.194120459957048,0.189207792049274,0.275957886129618,0.757066001417115,0.0808777082711458,NA,2 -1.08233265888858,0.377081562764943,0.927687839372084,0.388060942757875,0.262930948287249,0.415721071185544,0.595551578560844,0.889767974615097,0.680972298840061,0.78021034412086,0.810993585269898,NA,1 --3.21122703337705,0.507098242873326,0.364786737831309,0.706525600515306,0.0720878560096025,0.39529197756201,0.0853938185609877,0.25313756079413,0.884323446080089,0.164221335435286,0.151563796214759,NA,2 -0.531353908252581,0.237350288545713,0.275147415464744,0.501605095108971,0.0984707665629685,0.615358766168356,0.561869938857853,0.290915884776041,0.80242047063075,0.554143948713318,0.496158995432779,NA,0.5 -0.600730293046287,0.157840367639437,0.632627204060555,0.929030198603868,0.630327652674168,0.730401052860543,0.328352578915656,0.717250395333394,0.666188669623807,0.0548528400249779,0.833407171769068,NA,0.5 -1.01314701175506,0.898484638426453,0.203661296050996,0.846307630417868,0.219352454645559,0.742492161225528,0.517635546857491,0.0817107316106558,0.968180059455335,0.336696855258197,0.662801184924319,NA,3 -0.656553622768969,0.753240135731176,0.343114467570558,0.128794493386522,0.560916883870959,0.668597492389381,0.196455833502114,0.172105104196817,0.634505996946245,0.0560522808227688,0.678417820017785,NA,3 -0.450441176587303,0.0125156014692038,0.882218054495752,0.629003981128335,0.674523104447871,0.600846989545971,0.633095325203612,0.400010914774612,0.962411692133173,0.547679667593911,0.177117998711765,NA,0.5 --3.7895427703596,0.756340250140056,0.397959120338783,0.865862431004643,0.327016743365675,0.974390563322231,0.230026092845947,0.127599307801574,0.0277679341379553,0.932881123386323,0.593616648577154,NA,3 
-0.946190693629125,0.235685531515628,0.246450350852683,0.248962875455618,0.362083269981667,0.932028053328395,0.458421003306285,0.347327328752726,0.816151189152151,0.349907604278997,0.481070955749601,NA,0.5 -9.4191799850465,0.849114913959056,0.821041927672923,0.602861605351791,0.340112386737019,0.396447837818414,0.765584346372634,0.106930397683755,0.223846387118101,0.744411860825494,0.93100854405202,NA,3 -1.43916127761723,0.766780707519501,0.445421120384708,0.451271537225693,0.72373710712418,0.184640744002536,0.951642854139209,0.495319120585918,0.00020794733427465,0.887012050719932,0.271078041987494,NA,3 -0.243285053031709,0.172001409810036,0.398085105465725,0.231227887328714,0.164392935112119,0.112842592177913,0.171439406462014,0.502675085794181,0.686919268220663,0.491456889547408,0.454194214195013,NA,0.5 -0.630674862417757,0.932543548289686,0.256847345735878,0.979184590745717,0.631192493718117,0.981027684174478,0.682858986081555,0.169451557798311,0.60733424173668,0.9371962856967,0.350068996660411,NA,3 -0.601104263344641,0.218169360188767,0.321275946451351,0.484401718014851,0.851453062379733,0.309203763958067,0.917364788241684,0.2084669705946,0.245045544346794,0.582212117500603,0.228992780670524,NA,0.5 -0.840336041358887,0.589697151444852,0.685158264124766,0.327503680251539,0.213558244984597,0.763619041070342,0.11949589359574,0.71811868599616,0.408380198059604,0.832140794256702,0.770084762945771,NA,2 --0.154074442892582,0.367701322771609,0.884693963453174,0.0697798361070454,0.434232867090032,0.898273871280253,0.254723551450297,0.111902467207983,0.715196632780135,0.184001997346058,0.933242507977411,NA,1 -0.720696373819241,0.198938927613199,0.606495931511745,0.336577468551695,0.326773484470323,0.370286533143371,0.00082359416410327,0.0124902881216258,0.239021955290809,0.908346865791827,0.0976885615382344,NA,0.5 
-2.21060504692689,0.682107795961201,0.287048922851682,0.372942043235525,0.396774863358587,0.819875442888588,0.710184759926051,0.190397032536566,0.480622512986884,0.8761212809477,0.223856529453769,NA,2 --0.00610575291022993,0.067679442698136,0.8684254467953,0.996917813783512,0.089049612171948,0.700067178346217,0.794930153992027,0.930825348943472,0.826459914213046,0.864694126881659,0.0812219593208283,NA,0.5 --0.932580761661321,0.572068417910486,0.0790525975171477,0.1397358069662,0.64970194036141,0.396581323817372,0.13856448000297,0.767174258362502,0.951541883638129,0.0230525995139033,0.823745897971094,NA,2 --1.23014290795405,0.450845072045922,0.671238768612966,0.580475129419938,0.165164324454963,0.766623474657536,0.279710618546233,0.353242868324742,0.870829950086772,0.200475953286514,0.171313758939505,NA,1 --3.16484165617904,0.573105162009597,0.258897884283215,0.213309426326305,0.100215684855357,0.882589143933728,0.585318948607892,0.152414533542469,0.104481044691056,0.666403246112168,0.75058483122848,NA,2 -4.62370372914815,0.73533759615384,0.740031143417582,0.788233703933656,0.918064421508461,0.720053877448663,0.39371900446713,0.0863925719168037,0.545766794355586,0.310717426007614,0.739832061110064,NA,2 -1.06384980222904,0.318392616463825,0.172780332621187,0.879689910449088,0.738148988690227,0.928580733481795,0.954701505368575,0.686478513525799,0.12954113073647,0.81541432836093,0.500270092394203,NA,1 -0.797694547807648,0.528178822714835,0.222754651214927,0.543697185115889,0.210871644085273,0.539274046663195,0.117441890062764,0.171201136894524,0.563767193350941,0.336862071650103,0.51251038024202,NA,2 --1.43612727789891,0.995858771493658,0.930734105641022,0.990805656183511,0.171419916907325,0.054182613035664,0.928726562298834,0.748863881221041,0.940693028038368,0.65918653504923,0.31154797365889,NA,3 
--0.662510088784892,0.113952487008646,0.250964630628005,0.410930230980739,0.10554705071263,0.907607411500067,0.386065973900259,0.522172022610903,0.355578461894765,0.730407647090033,0.462258468614891,NA,0.5 -1.38840466048108,0.341306684771553,0.812958737369627,0.613608168903738,0.835363352671266,0.894293455639854,0.379676664015278,0.66161756683141,0.392706391168758,0.0179466558620334,0.213412882061675,NA,1 --0.00107339168469423,0.467106574913487,0.461363392416388,0.494308657478541,0.868279299233109,0.097200368065387,0.598206092603505,0.0933514924254268,0.761134962551296,0.0639440342783928,0.236494644545019,NA,1 --0.781708262436574,0.128517229110003,0.170633271569386,0.299606638494879,0.496000039856881,0.510280701331794,0.300075584789738,0.853872763924301,0.197547832271084,0.155960420612246,0.643106105038896,NA,0.5 -0.341957823006054,0.412145783891901,0.359467171831056,0.964134642854333,0.698599802330136,0.409747923957184,0.337395059876144,0.43656346690841,0.678571406751871,0.0371420315932482,0.82975935889408,NA,1 -0.056551295167318,0.222288713557646,0.49123525666073,0.395386436022818,0.539690778823569,0.574697827454656,0.397546735825017,0.730274567846209,0.652596935629845,0.617730791214854,0.38262702873908,NA,0.5 -0.106784845057808,0.148283806396648,0.0914904365781695,0.869412067811936,0.399036906659603,0.190275267930701,0.531700859311968,0.970284502254799,0.63110744045116,0.105896403081715,0.895602686330676,NA,0.5 --0.416591273016816,0.0369925673585385,0.369142951210961,0.802026466466486,0.83644736954011,0.384842669591308,0.399965344229713,0.497230796376243,0.21596611966379,0.430716055911034,0.814202995272353,NA,0.5 --3.80115746542674,0.896795936627313,0.765040824888274,0.844353931955993,0.740446796640754,0.11038355785422,0.944162652362138,0.749065604060888,0.802961661946028,0.594114093808457,0.982326869154349,NA,3 
--0.990579577345106,0.364397091092542,0.75387371936813,0.92530464194715,0.556107219774276,0.472777665359899,0.290147027932107,0.558879483724013,0.734846210572869,0.912317100213841,0.382154474034905,NA,1 --2.24113101562438,0.958884046878666,0.916358025046065,0.622180263744667,0.313185492297634,0.629507611738518,0.766538005322218,0.00391310267150402,0.117573201656342,0.677621009293944,0.86312605952844,NA,3 --1.37915963138867,0.992283673025668,0.1130617109593,0.39246118767187,0.386350285029039,0.618687929585576,0.00955063896253705,0.3101056877058,0.00335579412057996,0.428965806029737,0.334215411450714,NA,3 --0.116138535683225,0.712791633559391,0.826705769635737,0.998709523584694,0.935079188318923,0.999019755981863,0.179841009201482,0.775375233031809,0.565244837198406,0.73825324466452,0.217810853617266,NA,2 -2.8651370208333,0.918227929389104,0.991976723074913,0.556448292685673,0.86033549089916,0.86803180235438,0.593546075979248,0.764600344700739,0.115971907973289,0.874116574181244,0.415115732233971,NA,3 --2.62673850900046,0.958541032392532,0.734085222240537,0.277872398728505,0.903584003215656,0.99169037095271,0.442168229259551,0.931520459009334,0.135586928809062,0.59293475211598,0.717808285029605,NA,3 -0.132124476008359,0.608879194129258,0.450298280920833,0.517520986497402,0.466313583776355,0.215587850660086,0.0298155974596739,0.189444416435435,0.174548792652786,0.0233730303589255,0.0774613090325147,NA,2 -0.293719995172287,0.0288681895472109,0.813382123131305,0.00113681610673666,0.941644994076341,0.476048385258764,0.530871283495799,0.56963656633161,0.737177671631798,0.0965919359587133,0.604818821884692,NA,0.5 --0.103954343912756,0.166011718567461,0.653191721998155,0.923596187727526,0.993117317091674,0.121775610838085,0.0122047434560955,0.826190118445083,0.30410159076564,0.38432927033864,0.433089199708775,NA,0.5 
-0.716889016296623,0.424093425739557,0.161204627482221,0.926760187838227,0.688538460293785,0.412991603836417,0.0354802808724344,0.914258436998352,0.532327418681234,0.665652722353116,0.169264691183344,NA,1 --3.89856674282553,0.847578614018857,0.657475670333952,0.999632467282936,0.367575454059988,0.757724079070613,0.600674423156306,0.687886233208701,0.487299331231043,0.809120106277987,0.675410686992109,NA,3 -0.167302299143473,0.0175340240821242,0.519420479191467,0.318057127762586,0.0525879818014801,0.545170752331614,0.493110234150663,0.703221583971754,0.326910747680813,0.955182061996311,0.553046975517645,NA,0.5 -3.41642433081155,0.881483160192147,0.548339178785682,0.000273799290880561,0.615487394854426,0.331853078911081,0.753756532678381,0.619998622685671,0.887993055395782,0.321768817491829,0.723241026746109,NA,3 --0.927780454804531,0.694804448867217,0.636846353765577,0.545869589783251,0.543147100368515,0.651934214634821,0.496286484412849,0.906899358844385,0.585167640121654,0.544923613313586,0.9785473162774,NA,2 --0.119458659437126,0.349967923015356,0.0674725724384189,0.461099995067343,0.208889465779066,0.727099789306521,0.322485944256186,0.0180952423252165,0.586417717626318,0.50533275748603,0.208145300159231,NA,1 --0.388475062884792,0.442952080629766,0.550757031654939,0.0116928820498288,0.919974668882787,0.28176667355001,0.343595980666578,0.00730579695664346,0.0586100376676768,0.816800056258217,0.281768965069205,NA,1 --4.99993600969702,0.63083560904488,0.304470317903906,0.130367321893573,0.00757993618026376,0.672845395281911,0.775554867926985,0.561601209221408,0.809073857031763,0.350174822146073,0.369892556220293,NA,2 --4.0932011291792,0.857258127070963,0.849426775239408,0.048665531212464,0.663260747911409,0.201854383805767,0.374728660099208,0.680570747936144,0.791422239039093,0.155649269232526,0.159891559742391,NA,3 
--1.63574784505844,0.475719389040023,0.463194821262732,0.562303752405569,0.269420249154791,0.64509441726841,0.99847622634843,0.944659149041399,0.570544230751693,0.00567547790706158,0.528426631586626,NA,1 --0.11708102669514,0.229862803127617,0.693337155273184,0.0249634576030076,0.472208397230133,0.628053227905184,0.81685835798271,0.652433575829491,0.879498174646869,0.0514549117069691,0.73972195899114,NA,0.5 -0.311440636819556,0.424796005710959,0.588375912979245,0.907698655501008,0.31491112918593,0.173823640448973,0.74140138248913,0.306885188445449,0.665830912999809,0.498456827364862,0.325599604984745,NA,1 --0.523524190351201,0.546965456567705,0.69669742975384,0.907464131480083,0.905781339621171,0.374616772402078,0.97064465098083,0.54574229917489,0.0258567882701755,0.499222912825644,0.301927668973804,NA,2 --0.218557554085927,0.545247800182551,0.532440177630633,0.0326926102861762,0.571362876798958,0.246572606964037,0.62373100570403,0.800766404950991,0.306481578154489,0.010504704201594,0.843400214333087,NA,2 -4.3394388956167,0.816038053482771,0.0158182233572006,0.43390632327646,0.118837420130149,0.119051147252321,0.583459883695468,0.119374568108469,0.339256392093375,0.430539290653542,0.673767470754683,NA,3 --3.26204680706395,0.901603788835928,0.348586000269279,0.311316560022533,0.66575604211539,0.234927158802748,0.4827365425881,0.146938317688182,0.157435120781884,0.520954007748514,0.697913567069918,NA,3 --1.34236542741488,0.620338303502649,0.691506293835118,0.481012200936675,0.556597996968776,0.0661866711452603,0.809830514946952,0.655171810416505,0.967796785058454,0.859031651634723,0.547678065020591,NA,2 -2.62955558280747,0.840353684034199,0.673164369771257,0.499901092844084,0.148192471591756,0.0805800317320973,0.836274195462465,0.529777634423226,0.809363054344431,0.569066174095497,0.855126766953617,NA,3 
--0.438281983773096,0.286419975571334,0.0967769827693701,0.10907745687291,0.96515504270792,0.346270593581721,0.213748984504491,0.87659911182709,0.000996117945760489,0.11449680221267,0.900713045150042,NA,1 --0.819668273218654,0.42447182838805,0.305468514794484,0.628598166629672,0.897104371571913,0.292935826117173,0.57036331994459,0.812557882862166,0.743328155716881,0.372404619818553,0.995886287419125,NA,1 --2.15502275916223,0.581904123770073,0.136866772547364,0.686216610949486,0.0187842161394656,0.65494451741688,0.842298676259816,0.63314122450538,0.811359524028376,0.159128638915718,0.789697944186628,NA,2 --4.06652032407565,0.945780334994197,0.103755530901253,0.866593056125566,0.559091728879139,0.662870865082368,0.727699366630986,0.921519288793206,0.788001266308129,0.731000347295776,0.126275058137253,NA,3 --0.928011560521798,0.0429409199859947,0.000712319742888212,0.487844629446045,0.285959116648883,0.604428091552109,0.919549201382324,0.459383252309635,0.727056132396683,0.891209361376241,0.935548935784027,NA,0.5 --0.454744311371444,0.328576610656455,0.304207463981584,0.804804563056678,0.846173255937174,0.218100384576246,0.161341489758343,0.0678012303542346,0.141843453980982,0.453309568809345,0.174222140572965,NA,1 -2.45897650312528,0.572264407062903,0.749721936415881,0.234257691074163,0.789959133137017,0.606415096437559,0.554639648413286,0.0167610717471689,0.64993330463767,0.166575252544135,0.753427488962188,NA,2 -6.3265939587816,0.819421303458512,0.737637117039412,0.894383451668546,0.761343186022714,0.651555245975032,0.303879029816017,0.53969519212842,0.311747952597216,0.319801090052351,0.836284123826772,NA,3 --1.00774699035296,0.27150860009715,0.301227852702141,0.904794576577842,0.526734083425254,0.869649238185957,0.472009405493736,0.629169559106231,0.432649733265862,0.571082683978602,0.996163306059316,NA,1 
-0.746872815646722,0.140957412077114,0.958929154090583,0.87989121559076,0.650756301125512,0.234146694419906,0.255590234417468,0.420677730115131,0.597164487466216,0.723511897493154,0.155153512023389,NA,0.5 --0.650446841065496,0.358473981032148,0.862018195679411,0.700934009160846,0.0534398728050292,0.899426093092188,0.338007411221042,0.93593748472631,0.137800792232156,0.827133719110861,0.955730211222544,NA,1 -2.48082038300247,0.961773466318846,0.0908011053688824,0.906130446586758,0.650059416424483,0.286194331245497,0.889659560285509,0.1561406950932,0.112980651902035,0.944480426842347,0.891218760749325,NA,3 --0.0986038818318802,0.0816032441798598,0.00994992698542774,0.0916588774416596,0.106353395618498,0.383079081075266,0.394785706652328,0.774634938454255,0.597865202464163,0.382828871253878,0.0871360464952886,NA,0.5 -1.1520722946698,0.372490294510499,0.642933483468369,0.932021666085348,0.25168374995701,0.656297161011025,0.551260613603517,0.00573913590051234,0.171782038640231,0.776553657837212,0.087719907052815,NA,1 -4.3995392071382,0.889362843940035,0.875090325716883,0.929815792944282,0.647907796083018,0.219110984588042,0.0118211782537401,0.674186609685421,0.940533248940483,0.127728166524321,0.00270650535821915,NA,3 -0.0290089026468435,0.248338227858767,0.142401040298864,0.763918829383329,0.76891547255218,0.181683097267523,0.87568672769703,0.820826095528901,0.332615875406191,0.810349359177053,0.414319772040471,NA,0.5 --0.0716587997708629,0.099602812435478,0.572981350123882,0.823288503335789,0.706261500250548,0.0959598675835878,0.149675080087036,0.507298650220037,0.423980995081365,0.74773308984004,0.686039772350341,NA,0.5 --1.30011461908445,0.42879819823429,0.0999327092431486,0.0399061276111752,0.891762524843216,0.942133626202121,0.152556500863284,0.170978014590219,0.368863882729784,0.0347012423444539,0.505937222158536,NA,1 
--0.60014108521951,0.138419456779957,0.436613920144737,0.604420283576474,0.204740515211597,0.765600503887981,0.98930400935933,0.276836613891646,0.0905646500177681,0.362719136057422,0.529882991919294,NA,0.5 --1.21145790381126,0.744622958125547,0.749749571783468,0.360840648878366,0.762236907379702,0.297599644400179,0.55331027880311,0.624136926373467,0.798092990880832,0.851898095104843,0.404137687990442,NA,2 --1.14999787644678,0.708947046194226,0.849501037504524,0.762910400982946,0.842422707471997,0.723715083673596,0.381835736334324,0.65285807219334,0.514471322065219,0.310605535982177,0.957349325995892,NA,2 -0.819526373875359,0.515504683367908,0.370702817803249,0.823700248729438,0.387957931729034,0.642785836942494,0.348471603821963,0.0168538340367377,0.692721860483289,0.232188782887533,0.389768423512578,NA,2 -0.04335007115952,0.00294706947170198,0.814211948076263,0.31387432734482,0.261198920663446,0.432101055746898,0.918104903073981,0.415187336038798,0.976240418851376,0.291730556869879,0.519225212512538,NA,0.5 --0.941306547537896,0.320613505784422,0.445245401933789,0.659601202700287,0.373574797762558,0.309444201411679,0.99856966570951,0.845955910626799,0.385660471161827,0.863749768352136,0.85754614090547,NA,1 --1.18186963828744,0.701132035115734,0.859337297501042,0.348278685705736,0.160000698640943,0.238924831384793,0.72800212725997,0.57119212881662,0.620406114961952,0.0212759706191719,0.0539104880299419,NA,2 -0.357126829159397,0.447363865328953,0.8170770870056,0.286243395647034,0.16769379354082,0.238466019975021,0.662134238751605,0.810090478044003,0.577125230105594,0.075156477978453,0.137602924602106,NA,1 --0.190270625930595,0.114963983185589,0.593716769479215,0.903442053124309,0.815722588449717,0.452215825207531,0.0901011251844466,0.436070464085788,0.167535469401628,0.971105119213462,0.308095294749364,NA,0.5 
-2.62078960911659,0.566922665573657,0.568221352994442,0.653094903565943,0.0865914954338223,0.786628634436056,0.917334056925029,0.248513758182526,0.225959765724838,0.70123139442876,0.785259994678199,NA,2 -2.55987015026694,0.833427351433784,0.499585702549666,0.598924180492759,0.138457585126162,0.320701913908124,0.774522821651772,0.431595224887133,0.257062387652695,0.22843798622489,0.532552591059357,NA,3 --6.92858680088879,0.782670683693141,0.0883395997807384,0.773956855759025,0.646301113301888,0.948011690285057,0.624721661442891,0.384072499116883,0.52077893470414,0.161269078962505,0.443590870127082,NA,3 -0.15815266439335,0.246600785059854,0.783747910521924,0.235627511749044,0.865159292006865,0.799394915346056,0.177220857003704,0.427418543258682,0.373554009012878,0.512462994782254,0.892843663692474,NA,0.5 -2.65450595933212,0.957361436681822,0.911815687082708,0.41771270474419,0.507127913879231,0.388280466897413,0.209281904157251,0.512741857906803,0.995567321544513,0.608058721991256,0.980684688081965,NA,3 --0.551655302147764,0.105588861741126,0.765772460727021,0.254358040867373,0.496851657284424,0.270303341327235,0.840589354978874,0.230695051141083,0.313359949970618,0.0843451307155192,0.604350767796859,NA,0.5 --1.8820051243768,0.865534041076899,0.282570220297202,0.518943741451949,0.582809855230153,0.733290414093062,0.464054192649201,0.0420361561700702,0.312909991247579,0.357775181764737,0.217115553095937,NA,3 -1.18644127377823,0.572014078265056,0.47516213497147,0.909601657185704,0.7345625765156,0.159067094791681,0.0758456904441118,0.596963451942429,0.248694570967928,0.609315068461001,0.460826584137976,NA,2 -2.90277751401675,0.714331415947527,0.171784817706794,0.755383554147556,0.103581401053816,0.949289575219154,0.311075175181031,0.892260275082663,0.782703787088394,0.609557135263458,0.0396083393134177,NA,2 
-0.452672729165082,0.362838387256488,0.290337451035157,0.464105833088979,0.632470370503142,0.288197767920792,0.256659790175036,0.76495176483877,0.331189053365961,0.102935964474455,0.724172668065876,NA,1 --0.133181544288443,0.697857573861256,0.296673942590132,0.952276430791244,0.360607230570167,0.295894445385784,0.471734570106491,0.0396309678908437,0.965557127026841,0.30653714039363,0.577830789843574,NA,2 -0.520898443292835,0.555650745285675,0.0409085555002093,0.172368012135848,0.690776515984908,0.844665887998417,0.950760766165331,0.402099822415039,0.815108865965158,0.204499331535771,0.857367512304336,NA,2 --4.94863899675115,0.611774453427643,0.924236066872254,0.0708308285102248,0.330309199402109,0.127171179978177,0.491658851737157,0.536354664713144,0.913808302022517,0.952496240613982,0.481220846064389,NA,2 -0.736866988322533,0.438154348172247,0.806642528390512,0.0169571470469236,0.662110002711415,0.0638061012141407,0.217253629118204,0.100625155959278,0.0314080747775733,0.836882830131799,0.524368022568524,NA,1 --2.10491296672012,0.603837736183777,0.607003055978566,0.14466291340068,0.00309824291616678,0.579994248226285,0.499280680436641,0.896025418769568,0.630717721302062,0.990853895898908,0.906316698761657,NA,2 --0.385675311337004,0.0125510159414262,0.516535924980417,0.448626236524433,0.0885730378795415,0.931544049177319,0.0455953879281878,0.932046544039622,0.456221962114796,0.18970851926133,0.15640650712885,NA,0.5 --0.933954992523,0.0433406794909388,0.950225162087008,0.479722076794133,0.840589659288526,0.512415238423273,0.36178726516664,0.267573584103957,0.603488262509927,0.0370614926796407,0.624879651935771,NA,0.5 --0.520842484729798,0.0929510626010597,0.185449638403952,0.0954606782179326,0.354706885525957,0.336916906060651,0.202741828747094,0.128091655438766,0.105544385500252,0.0621682242490351,0.337343172635883,NA,0.5 
--1.1902755322199,0.341341719729826,0.27227359986864,0.726273648440838,0.375441956566647,0.0984445584472269,0.24713291740045,0.68887278297916,0.395826757419854,0.263939680531621,0.85243204399012,NA,1 --0.0160453226366918,0.598223075270653,0.515979057410732,0.294184898026288,0.419611923396587,0.451983051374555,0.610102006234229,0.502045850269496,0.0474624985363334,0.312018062686548,0.97183643002063,NA,2 -0.662189390327135,0.269035245059058,0.890601610997692,0.579056739108637,0.532028702553362,0.036886066198349,0.655321523780003,0.785500477533787,0.994296758901328,0.625912741292268,0.597492320695892,NA,1 --1.47635093715271,0.519639003556222,0.0640717146452516,0.487221607938409,0.496073189657182,0.116159395547584,0.006662578554824,0.0628067965153605,0.91214673826471,0.58853417634964,0.529160888632759,NA,2 --0.54737069346927,0.0266456077806652,0.749969109194353,0.375706146936864,0.8472254879307,0.302311453036964,0.892074418487027,0.566268016584218,0.896159957861528,0.375780015485361,0.340594119857997,NA,0.5 -1.45065191789749,0.94707443495281,0.425205362495035,0.337836832040921,0.73194540082477,0.0198679927270859,0.384525762172416,0.486763422377408,0.31629200046882,0.698771668132395,0.0593036743812263,NA,3 --0.36734327686679,0.449535029008985,0.739834612002596,0.998082731151953,0.405127953272313,0.798034045379609,0.383507769554853,0.453684104373679,0.366463227430359,0.449634636053815,0.556320760864764,NA,1 --0.142193432562441,0.0993167369160801,0.780165972188115,0.76557879964821,0.643417255720124,0.0841845641843975,0.952992078615353,0.388045005733147,0.353067258605734,0.494073997018859,0.38412378448993,NA,0.5 --0.0204411745838119,0.182601721491665,0.671439114259556,0.933410462690517,0.063336655497551,0.283508115913719,0.0870876528788358,0.795035666320473,0.261529661715031,0.253830597037449,0.271484842523932,NA,0.5 
--0.145046151664827,0.412740072468296,0.922483856091276,0.205560039961711,0.281724784523249,0.308287980500609,0.114223259268329,0.339805831899866,0.59704695106484,0.0895669923629612,0.20351308491081,NA,1 --0.655416133968273,0.132374021690339,0.0385842223186046,0.960168987046927,0.553003982873634,0.0313359131105244,0.998100819531828,0.0783406211994588,0.927472871961072,0.659052587812766,0.785189574118704,NA,0.5 -0.0569453360323003,0.455947425216436,0.609724464127794,0.614067499991506,0.97299245535396,0.99231677222997,0.702998753637075,0.760542156873271,0.224094831617549,0.597768754232675,0.174418788170442,NA,1 --0.407626682681292,0.138098580995575,0.752001489046961,0.289834522875026,0.654601774876937,0.41541023994796,0.00403436576016247,0.770590521162376,0.0943589757662266,0.306590826017782,0.670288912020624,NA,0.5 --1.68139582769653,0.969414204824716,0.606616077246144,0.349470247514546,0.341198356356472,0.052086751209572,0.360128439497203,0.27755997586064,0.483060038648546,0.592371104983613,0.179017000831664,NA,3 --1.41235887891515,0.484311842592433,0.627842574846,0.414641877403483,0.304095499916002,0.200165947666392,0.097875046543777,0.366788988700137,0.650650549912825,0.789857445284724,0.357476759934798,NA,1 -0.452289857944087,0.612423501210287,0.262217645766214,0.175194005016237,0.517215595347807,0.658887927187607,0.398805856704712,0.485592402284965,0.953524999320507,0.391746623674408,0.300889120902866,NA,2 --2.08435011464261,0.759304512990639,0.582466018851846,0.261647397186607,0.554752136347815,0.801065125502646,0.137913859682158,0.877758507616818,0.749789121560752,0.467832755995914,0.0966854214202613,NA,3 --0.250461882730123,0.282862489577383,0.311645941110328,0.60926196561195,0.124141810229048,0.274296503281221,0.622827796963975,0.743674198165536,0.265039020916447,0.945896445307881,0.395986490882933,NA,1 
-1.38929004072203,0.581831085495651,0.604764067567885,0.0981524740345776,0.640039406251162,0.59033013205044,0.712040337501094,0.564666945487261,0.492797074839473,0.603931111749262,0.279830668587238,NA,2 -3.37308806241796,0.644188753794879,0.11642846907489,0.388798480387777,0.456940018106252,0.958898683078587,0.534385369159281,0.026345114922151,0.103369149379432,0.836727291345596,0.575546889333054,NA,2 --1.80625513593003,0.665181114571169,0.865635313559324,0.236588786588982,0.625312402611598,0.197498777182773,0.469960128655657,0.853968983516097,0.0933152246288955,0.272778454935178,0.137419609585777,NA,2 --0.598181630839518,0.20983607089147,0.388797290390357,0.747019624104723,0.93496206542477,0.0514743681997061,0.420886440202594,0.680075784912333,0.825326177757233,0.238552219001576,0.866760726552457,NA,0.5 --0.678701228012741,0.367456012638286,0.76701890816912,0.299027033150196,0.638889130437747,0.973676205612719,0.202693411381915,0.047896116040647,0.693167041055858,0.72095926781185,0.85648188367486,NA,1 --0.439448370591987,0.171001069946215,0.7493276654277,0.766403131186962,0.142307656118646,0.68866288033314,0.379330020165071,0.842053881846368,0.453404651256278,0.344623880228028,0.187736973864958,NA,0.5 --0.429402457336857,0.232082100119442,0.912413143087178,0.417194787180051,0.263550848001614,0.0323172744829208,0.70294936420396,0.237633729353547,0.179497604025528,0.155822693370283,0.655707615427673,NA,0.5 -0.260273411620581,0.397116974228993,0.285043719923124,0.626160256098956,0.440539639210328,0.228579026414081,0.950656231259927,0.738680086331442,0.909562820801511,0.836534011876211,0.909986349521205,NA,1 --0.601741747228831,0.00319111370481551,0.801924289902672,0.936224457342178,0.0328773867804557,0.680261990521103,0.817012073937804,0.107413252815604,0.0303064233157784,0.471293790731579,0.67827691603452,NA,0.5 
-0.579175363038437,0.382363738259301,0.393253202782944,0.158311059698462,0.802881805924699,0.164180209161714,0.210846305359155,0.0777299583423883,0.271905378438532,0.415683274855837,0.646119900047779,NA,1 -0.233684619832068,0.104158732341602,0.044687672983855,0.755966236116365,0.898829346057028,0.747886597644538,0.785850781947374,0.187898478936404,0.26085401698947,0.919391671195626,0.723531056428328,NA,0.5 --1.18176505505841,0.748055958654732,0.457851321436465,0.75289491051808,0.874799456913024,0.5951328065712,0.474716911558062,0.540802971692756,0.586096824146807,0.983661994105205,0.879120596917346,NA,2 --0.899714329525814,0.598819239530712,0.214800329878926,0.255403825780377,0.657175703207031,0.83159053674899,0.0539117860607803,0.840320016723126,0.30811989097856,0.841044461587444,0.933701136615127,NA,2 -0.646093632296236,0.36379097471945,0.930275573628023,0.413587759248912,0.677580804331228,0.780740413581952,0.0646629584953189,0.859191400231794,0.603091831551865,0.680502528557554,0.776847406988963,NA,1 -1.02643265104361,0.787000266136602,0.130031454842538,0.812010727357119,0.11337968474254,0.226478895405307,0.9479467747733,0.289414885919541,0.826333333505318,0.482536654453725,0.204640087904409,NA,3 --1.60873401630344,0.333655280526727,0.473265617853031,0.913793664192781,0.927346032112837,0.598311467794701,0.802216146606952,0.604303786531091,0.204727914649993,0.256905679125339,0.353850594256073,NA,1 --0.193922357216301,0.309134368319064,0.69424194470048,0.045563388383016,0.0289882610086352,0.702026945538819,0.353906127158552,0.578083222731948,0.998108556028455,0.706539734266698,0.111041986150667,NA,1 -2.25822378003256,0.699067022651434,0.175160866929218,0.675344614312053,0.653209868585691,0.448388746008277,0.639876928878948,0.529742702143267,0.386671084910631,0.871713967528194,0.597260946873575,NA,2 
--0.294217127054907,0.0130114420317113,0.745208188192919,0.662316924892366,0.869964884594083,0.346366792684421,0.134350478416309,0.91824146383442,0.628984485287219,0.677548446692526,0.685797044774517,NA,0.5 --5.09182856615671,0.535232451977208,0.519287661882117,0.85244823875837,0.555158670060337,0.251911287195981,0.740879685385153,0.578002710128203,0.67449068557471,0.13997372216545,0.171553277410567,NA,2 --0.957106599062254,0.260365868685767,0.039272784255445,0.205235176021233,0.658755406038836,0.0693191748578101,0.265798757784069,0.553729085018858,0.56509280600585,0.896149475127459,0.0161577914841473,NA,1 --0.74547676965435,0.216713041067123,0.312677611829713,0.271783971460536,0.31258171889931,0.547919424949214,0.143913608510047,0.106598092708737,0.155940489610657,0.147696991218254,0.264999630395323,NA,0.5 --1.7556278223873,0.301559766521677,0.813955902820453,0.378281063167378,0.23688638349995,0.825144364265725,0.807417817413807,0.154609765624627,0.987439922289923,0.27786693489179,0.650420303922147,NA,1 --0.0470943359654031,0.160312748048455,0.732178756734356,0.151784234680235,0.372392004355788,0.374390203272924,0.544937105616555,0.617528083734214,0.677572895074263,0.368166692554951,0.655130254570395,NA,0.5 -0.124658668723908,0.182294715894386,0.995739555219188,0.113104562275112,0.11240058299154,0.208648389670998,0.537716497201473,0.0649266852997243,0.650002721697092,0.0501199315767735,0.357121436158195,NA,0.5 -5.21455443964906,0.779459196841344,0.449038189370185,0.292015180690214,0.143725096713752,0.904136493103579,0.608930496964604,0.354715100722387,0.194870956009254,0.843750018393621,0.0934222042560577,NA,3 -1.75772733370199,0.821147922892123,0.425148322479799,0.171820772113279,0.235684452112764,0.761053857393563,0.315450295107439,0.220257700188085,0.835394476307556,0.984954028623179,0.115882620681077,NA,3 
--2.47988607885718,0.629339757142588,0.828968267887831,0.995918438304216,0.607340537942946,0.0347810362000018,0.946215031202883,0.967786622000858,0.126801078440621,0.765530567616224,0.121757410932332,NA,2 --0.894667690433518,0.781423395499587,0.0622167293913662,0.272649483056739,0.198446429567412,0.659495470346883,0.81549444841221,0.649437340209261,0.403396281879395,0.0653008993249387,0.449648660607636,NA,3 --2.59967794033893,0.833517065504566,0.922720778966323,0.283613242208958,0.117495934013277,0.44706165464595,0.493439532117918,0.653966138372198,0.117235817248002,0.946330566890538,0.346774617442861,NA,3 -0.562866817007894,0.066422026604414,0.0124633046798408,0.348506391048431,0.520839661592618,0.444712787633762,0.71667861030437,0.328789311693981,0.026976422406733,0.644805737538263,0.929708110401407,NA,0.5 --0.98353176943091,0.345342023530975,0.675448514288291,0.863113478990272,0.767568795010448,0.745967566501349,0.583446231437847,0.355865387944505,0.183258432662115,0.327033995883539,0.742279504658654,NA,1 -1.39982507996358,0.301927540684119,0.425509258871898,0.963772800983861,0.432708555832505,0.552789221983403,0.862505354220048,0.194295163266361,0.607898182235658,0.666282393038273,0.855093700578436,NA,1 -2.19108366345933,0.784785338677466,0.910598614718765,0.0310040663462132,0.792512105545029,0.913687644060701,0.779540273128077,0.203273280523717,0.580115466378629,0.778640141477808,0.0813250311184675,NA,3 -0.102754799970055,0.0151381115429103,0.303750258870423,0.530344953527674,0.398493434069678,0.551140069263056,0.331004866398871,0.0359833822585642,0.954829844180495,0.751596683170646,0.923089451156557,NA,0.5 --2.71593227987405,0.975154257612303,0.158668037503958,0.417390400543809,0.100585088133812,0.454939389135689,0.426590847782791,0.1122667635791,0.172702635638416,0.360355676151812,0.969007642241195,NA,3 
-0.11316914619458,0.144407300045714,0.726451003458351,0.323942760471255,0.0605967913288623,0.890680358512327,0.902225363301113,0.955268873134628,0.0258821186143905,0.516488530673087,0.00521893566474319,NA,0.5 -0.214612463434781,0.382541353581473,0.541727594798431,0.679740986786783,0.000293354038149118,0.129221958573908,0.0420920078177005,0.329942629439756,0.341416005743667,0.02168780984357,0.247941843234003,NA,1 --3.25744451413655,0.878030564635992,0.612484061392024,0.841921371873468,0.389507194282487,0.873570098308846,0.941481810528785,0.270786457462236,0.105030868668109,0.0905804259236902,0.677928207209334,NA,3 --0.612778977564602,0.299138211412355,0.71343940589577,0.362453953828663,0.569098226493225,0.806055163498968,0.984774222364649,0.0434266540687531,0.158101887907833,0.446324746357277,0.10046491608955,NA,1 -0.768254629890036,0.090535826748237,0.598878909833729,0.383121004793793,0.942228753585368,0.728780089877546,0.816637960728258,0.63866300880909,0.776843504980206,0.295215973397717,0.961366922827438,NA,0.5 --0.567603909448942,0.0424070812296122,0.419121650280431,0.861307024024427,0.638702552299947,0.642014807322994,0.769632588373497,0.153189419070259,0.810209463350475,0.170681128744036,0.77459627809003,NA,0.5 --5.9771524195317,0.862377436831594,0.652252439875156,0.904387121321633,0.797982558840886,0.350751545047387,0.78046753979288,0.985694216098636,0.70478395326063,0.913390251807868,0.0379480242263526,NA,3 -3.06750410813849,0.80633880593814,0.679136122344062,0.952588496729732,0.829646679805592,0.881833428051323,0.884054609807208,0.0632740694563836,0.70828592334874,0.175746616674587,0.878892200067639,NA,3 -0.957806121739846,0.112819531699643,0.528766673523933,0.779676516773179,0.985083922278136,0.22144174692221,0.14839206263423,0.560842191567644,0.211447972571477,0.751184774329886,0.262189727742225,NA,0.5 
-0.7394269431204,0.0613693608902395,0.0450991443358362,0.469146619085222,0.236004963051528,0.938807192258537,0.464330278337002,0.197913277195767,0.454433776205406,0.08160714013502,0.868174660252407,NA,0.5 -2.89437745658501,0.765211458783597,0.946839487645775,0.702625217847526,0.89387573953718,0.322030391078442,0.635835886467248,0.490745102521032,0.251796741737053,0.57421529921703,0.919654391240329,NA,3 --0.837698154337176,0.911105049075559,0.887382954359055,0.218226345721632,0.503832868300378,0.00183678534813225,0.586421895073727,0.72927557816729,0.53785819420591,0.782691985368729,0.0610043394844979,NA,3 -1.61332292935373,0.767474424093962,0.657503654249012,0.23821485415101,0.3197404849343,0.848195793339983,0.523012418299913,0.220011793542653,0.923964306944981,0.419060569256544,0.20017861854285,NA,3 --0.111281104502172,0.136224425164983,0.408692202530801,0.915172693319619,0.946666706586257,0.544694220647216,0.307216387009248,0.844691653503105,0.402090376475826,0.111828369786963,0.364809120539576,NA,0.5 --0.743268442892708,0.631344011286274,0.416102928575128,0.94413023837842,0.405237703584135,0.0686362243723124,0.856442781630903,0.569145374232903,0.298357689753175,0.543654523789883,0.354759010253474,NA,2 --0.905598862688402,0.429458065889776,0.599177881376818,0.525696570519358,0.556089870398864,0.374567771330476,0.73522638506256,0.165942407678813,0.241547536570579,0.640308366389945,0.910148287890479,NA,1 --0.381371694003123,0.319485255284235,0.775019305292517,0.202426716685295,0.393176882760599,0.383716176263988,0.833298979094252,0.63092312333174,0.456514721736312,0.263742276001722,0.599667183123529,NA,1 --0.423702137252754,0.11936202365905,0.114415199263021,0.187402412761003,0.825105560244992,0.401661385083571,0.598708146018907,0.626203504158184,0.0329728100914508,0.0982055109925568,0.259948245948181,NA,0.5 
--0.805397857933071,0.126862512435764,0.561627364018932,0.226535556837916,0.0691305936779827,0.0460475136060268,0.0944148511625826,0.584244670113549,0.75893074576743,0.932498198468238,0.695256362203509,NA,0.5 -0.319155083005407,0.472723427927122,0.787534170318395,0.645320721901953,0.715023176278919,0.209356067935005,0.81911487551406,0.801698182709515,0.70181589666754,0.698942579096183,0.469195405719802,NA,1 --0.0907648903587075,0.463830034947023,0.883693974465132,0.996114502195269,0.933144294423983,0.130785294808447,0.672029396053404,0.39099053083919,0.82090491685085,0.889978097518906,0.481951746391132,NA,1 -0.685509547256731,0.733466965379193,0.0820799909997731,0.0614028854761273,0.193428620696068,0.842947125434875,0.147228190675378,0.63629440870136,0.00185011373832822,0.664431531913579,0.359879235271364,NA,2 --6.88310219990682,0.766328622121364,0.514705869369209,0.0641286331228912,0.71263275318779,0.865839602192864,0.317526940023527,0.472104668151587,0.0943877408280969,0.525016574189067,0.392749238293618,NA,3 --2.30749559493075,0.317944551585242,0.313390593510121,0.785861490992829,0.626033872598782,0.691302561433986,0.909780625021085,0.636675088433549,0.000775783322751522,0.607812081230804,0.350837268168107,NA,1 -0.190249220215772,0.0292525580152869,0.784951515262946,0.726968697272241,0.515227085910738,0.952373256674036,0.985080229351297,0.0235045454464853,0.00416592671535909,0.967333158710971,0.396277377847582,NA,0.5 --1.55456092344735,0.298712126677856,0.329398845555261,0.824424406746402,0.504606149625033,0.462511173449457,0.987116791075096,0.984700593166053,0.738257547840476,0.506412111455575,0.215016127098352,NA,1 --5.71298108096891,0.863856402924284,0.472625948023051,0.0874699417036027,0.909145939163864,0.37907059583813,0.480933768674731,0.478833106346428,0.669854396488518,0.0413195120636374,0.272561500314623,NA,3 
-2.72257293704493,0.977674740599468,0.757155697094277,0.207756455754861,0.203775214497,0.78417210560292,0.582602435722947,0.666409130906686,0.686884060502052,0.110600611427799,0.558197030797601,NA,3 -10.2974933407333,0.817447017878294,0.907617374090478,0.650747747858986,0.467643693089485,0.37755541363731,8.57806298881769e-05,0.00122400932013988,0.270136442501098,0.128451892407611,0.193024599459022,NA,3 --0.515443771934163,0.156245535472408,0.951601629611105,0.789141749264672,0.919310998637229,0.852464217459783,0.700894636334851,0.253609915263951,0.141912077786401,0.594247984467074,0.577542662387714,NA,0.5 -3.60755387259999,0.935209455899894,0.168847924564034,0.71698178560473,0.164761062944308,0.820564121939242,0.269563676323742,0.865492029348388,0.0663771124090999,0.757397901965305,0.30812180461362,NA,3 -1.53424165667702,0.495078047504649,0.00743950600735843,0.699956620577723,0.150095566408709,0.820907787652686,0.0250667075160891,0.219115395331755,0.0701660695485771,0.144342829938978,0.661241426831111,NA,1 -2.11945690269233,0.494878040393814,0.335358702111989,0.0859456097241491,0.334762825863436,0.436428788118064,0.871063163038343,0.862540878355503,0.149437770945951,0.172741169808432,0.301684214035049,NA,1 --0.764114777927609,0.472692873096094,0.913892224663869,0.814544472144917,0.345459071453661,0.473493713885546,0.772064296295866,0.964068693341687,0.302707911469042,0.310749824624509,0.326807718491182,NA,1 -1.59341061109396,0.607975764898583,0.0794732617214322,0.974691078998148,0.617390216095373,0.461572230560705,0.722619395470247,0.609765309141949,0.594569276785478,0.192564327269793,0.0523809674195945,NA,2 --1.05008884817102,0.458628129446879,0.276793640572578,0.394021584419534,0.216551521327347,0.656353653175756,0.901549807749689,0.893069363664836,0.394100816454738,0.276124453404918,0.0872335392050445,NA,1 
--0.454916415458206,0.100318351062015,0.378878374584019,0.0983213325962424,0.174892703304067,0.711960697313771,0.919599503278732,0.190358696039766,0.269345039268956,0.287819736870006,0.287268604617566,NA,0.5 --0.485510412623777,0.241354882251471,0.949344601249322,0.158448557602242,0.133629100397229,0.866928705945611,0.937315163202584,0.184269865509123,0.327862715348601,0.563999371370301,0.785677004838362,NA,0.5 -0.776775239291567,0.291910991771147,0.0846551819704473,0.77373803849332,0.182577745057642,0.20777322165668,0.364862532587722,0.302542981225997,0.99205335858278,0.101478747557849,0.831732955761254,NA,1 --0.503915118460329,0.0921388799324632,0.69793195463717,0.640171206556261,0.433864553691819,0.868978491751477,0.679959014523774,0.123384471517056,0.638639747863635,0.164615281391889,0.954158689128235,NA,0.5 --2.45868231246521,0.486175883794203,0.204146457370371,0.274052948923782,0.427365210838616,0.0658067418262362,0.238397930283099,0.715437222272158,0.305063903564587,0.979477859567851,0.535615005297586,NA,1 -3.07267136331703,0.672265029046685,0.683067857287824,0.300874865613878,0.990141049260274,0.336872167419642,0.262356182327494,0.166166434762999,0.507457473315299,0.182093014242128,0.988365483703092,NA,2 --1.35725074282868,0.373770251171663,0.989214881556109,0.348896064329892,0.00952328857965767,0.870421917643398,0.0278489568736404,0.971605542581528,0.0709339645691216,0.96520364494063,0.830430411733687,NA,1 --2.09439816462769,0.6810261097271,0.785620857728645,0.798142119776458,0.381587166339159,0.372327584540471,0.625061142724007,0.36749339546077,0.582425137981772,0.641279636882246,0.120793332578614,NA,2 --0.00796859068675282,0.0229878642130643,0.855209650471807,0.982022156473249,0.256182591430843,0.777578515000641,0.825847221305594,0.558763438137248,0.683600728865713,0.737240033224225,0.413099535508081,NA,0.5 
--1.4979445469424,0.714179781964049,0.706974609289318,0.173396941041574,0.230190037516877,0.856905392371118,0.897651092847809,0.326673374744132,0.499662009300664,0.31590295699425,0.814334665657952,NA,2 --0.439028039829894,0.69282753020525,0.235939425881952,0.237423748476431,0.719278433360159,0.29460807191208,0.928081472637132,0.536010401323438,0.794380184262991,0.658699807012454,0.520766410976648,NA,2 --0.242951793769155,0.243101859698072,0.307146482169628,0.843962025130168,0.0316330003552139,0.430362788261846,0.351072994992137,0.654738643905148,0.969542238861322,0.60179814370349,0.398643856402487,NA,0.5 --0.630263115026057,0.200559745077044,0.907688920153305,0.334342788904905,0.150655995355919,0.972313160542399,0.542266771197319,0.50963302818127,0.218434902839363,0.748211001046002,0.228817308787256,NA,0.5 -0.696031295956962,0.328547941055149,0.876553914509714,0.164135021390393,0.395798878511414,0.477207755902782,0.283287385711446,0.191637014970183,0.346791804069653,0.0879971194081008,0.774699172470719,NA,1 -1.86218635994138,0.667642474640161,0.175826355349272,0.104500935412943,0.188644989626482,0.935310949804261,0.534588502487168,0.650544292991981,0.280430254060775,0.881018078653142,0.322163599310443,NA,2 -0.0618266822267202,0.966812142170966,0.964485211065039,0.745329377939925,0.468591102631763,0.238328024744987,0.248983833007514,0.304753655567765,0.806436447193846,0.374863224104047,0.268298006616533,NA,3 --0.213469050164475,0.626246140571311,0.371381344739348,0.64032802847214,0.563473800662905,0.100175821688026,0.0235935947857797,0.410097931977361,0.984508452471346,0.908335901331156,0.106187571771443,NA,2 -2.71820413597084,0.342010788153857,0.705079389037564,0.551346761640161,0.598823779728264,0.448870678432286,0.737467909697443,0.322141070617363,0.551015156554058,0.226791883818805,0.913258621236309,NA,1 
--1.01015670294422,0.765705335186794,0.826494007604197,0.231260113185272,0.565515470458195,0.248132557841018,0.171093754237518,0.286525617586449,0.441912526264787,0.227831450290978,0.286231678677723,NA,3 --3.14339298170896,0.951644803164527,0.232232666807249,0.908373491605744,0.913206648547202,0.407022410538048,0.953037474770099,0.274196953279898,0.27726714592427,0.842141730245203,0.17599373916164,NA,3 -0.816940374582757,0.00308033730834723,0.223011573543772,0.0552042822819203,0.646219877526164,0.89388394379057,0.35093015129678,0.0684547659475356,0.56139280856587,0.433695368934423,0.780848235357553,NA,0.5 --1.06114260561244,0.488627271028236,0.343027671799064,0.495272737927735,0.920322231249884,0.0372702074237168,0.504528037970886,0.391261521494016,0.130302850389853,0.576689033070579,0.0055089273955673,NA,1 -3.14001425018611,0.936363400658593,0.68013828061521,0.900950629264116,0.699788666563109,0.58299405220896,0.566819867584854,0.539389610989019,0.598359101684764,0.961430250899866,0.322504172567278,NA,3 -3.75658460367109,0.72642320394516,0.340536155505106,0.644040668383241,0.0150993110146374,0.76843476947397,0.496726306155324,0.994598030112684,0.0396893250290304,0.878083458403125,0.48237074050121,NA,2 -1.12695583768729,0.983749621314928,0.788466003490612,0.593351603019983,0.94108819635585,0.306970248464495,0.201373670483008,0.209501806180924,0.305320084095001,0.451809530379251,0.385713089257479,NA,3 -1.52459560077634,0.672651072964072,0.46188158262521,0.209960805252194,0.545736834639683,0.21016333415173,0.909396487753838,0.140858779894188,0.656669340096414,0.829055720707402,0.514228889951482,NA,2 -0.963944520418568,0.0890714421402663,0.796551150269806,0.0773013802245259,0.993034319253638,0.18855186062865,0.839895894518122,0.13716762047261,0.883261514361948,0.305906966794282,0.866397253004834,NA,0.5 
-0.129771290079079,0.189816636266187,0.434954048832878,0.614785549463704,0.0503757097758353,0.514615601161495,0.952145036077127,0.637101435568184,0.426087585045025,0.311506387079135,0.234520558966324,NA,0.5 -0.912884979700598,0.683184877503663,0.510933240642771,0.473237108206376,0.597530248807743,0.262252091662958,0.373021523468196,0.595309577649459,0.412889627274126,0.824333656812087,0.681659359484911,NA,2 --2.09362850541451,0.426857191603631,0.896169882267714,0.314928279956803,0.495635512284935,0.482654591090977,0.713752283481881,0.620617669541389,0.778110204497352,0.956990454345942,0.608790414175019,NA,1 --0.0644100579741575,0.721381928073242,0.246724736411124,0.742296724347398,0.949994837632403,0.286190049722791,0.0760854352265596,0.445758877322078,0.949769111117348,0.925280713476241,0.499942892463878,NA,2 -0.10139334086343,0.0579308001324534,0.994543595705181,0.293401257600635,0.00884762965142727,0.889348507625982,0.549911052221432,0.615719662979245,0.280096381437033,0.143207954941317,0.271301218308508,NA,0.5 -2.06653528604937,0.518336966400966,0.506918795639649,0.466876366641372,0.620179544668645,0.486438295803964,0.383961294777691,0.519759878516197,0.0823564794845879,0.572632172610611,0.927550779422745,NA,2 -0.083993183366566,0.792077120160684,0.0227267004083842,0.350600538309664,0.802795631811023,0.271761489799246,0.752974940929562,0.58151887729764,0.448455915320665,0.114585774950683,0.480480066733435,NA,3 -0.871280979378034,0.155268164584413,0.169720521196723,0.209214897826314,0.969444170594215,0.455098179168999,0.162548776483163,0.30731915589422,0.797547695925459,0.937418807530776,0.699685209430754,NA,0.5 --0.993675609227815,0.31696749618277,0.255061091855168,0.329014195129275,0.311154339462519,0.901402737712488,0.881375814089552,0.592734534991905,0.0463438257575035,0.407562130596489,0.94760683667846,NA,1 
-3.87620571131626,0.712310006609187,0.150037638610229,0.24496771581471,0.652770188869908,0.754434264730662,0.353124728426337,0.437375633278862,0.317781692370772,0.994260160252452,0.886606745189056,NA,2 --0.184073402751879,0.302567538106814,0.236578963929787,0.529907553223893,0.0782564121764153,0.575638748938218,0.953494219575077,0.74427487119101,0.285246771760285,0.675243055913597,0.700421273009852,NA,1 --0.82056052576034,0.607995933154598,0.269555284408852,0.859022193821147,0.785452287411317,0.556223073042929,0.531818702351302,0.694545754231513,0.475929024163634,0.182002258952707,0.0817325403913856,NA,2 --0.128441943850425,0.689589613582939,0.838425317080691,0.0708509180694818,0.0456330950837582,0.244261574465781,0.352208221564069,0.257190010044724,0.259088592603803,0.793637574184686,0.251339136157185,NA,2 --1.67953017850423,0.513718589907512,0.177739400183782,0.135568092111498,0.754302552901208,0.119685888523236,0.271099909441546,0.881469729356468,0.380955666769296,0.0594409927725792,0.285897349473089,NA,2 -4.01029674994469,0.637662801891565,0.681524514220655,0.0506364298053086,0.195933144073933,0.177333432948217,0.206886797677726,0.685570570174605,0.356021656654775,0.985714256530628,0.845266965450719,NA,2 -0.418247462976808,0.156605343334377,0.341756730806082,0.963353005703539,0.899666843237355,0.4310629342217,0.946215359028429,0.897696364205331,0.776364185847342,0.420296776574105,0.381828723009676,NA,0.5 --1.51268669424462,0.62234056647867,0.577063782373443,0.603003362193704,0.446018574293703,0.340106446295977,0.65706201409921,0.152433987008408,0.260849780868739,0.550666619790718,0.255014094058424,NA,2 -4.89636360756605,0.807543221861124,0.476624209433794,0.202747098170221,0.38368299161084,0.331279368372634,0.282060217810795,0.374299304792657,0.750059184152633,0.606916478602216,0.0227296042721719,NA,3 
-2.17393837350942,0.280177400680259,0.634965054458007,0.408001474570483,0.916558355325833,0.714338184334338,0.583447890123352,0.470874862978235,0.0685866239946336,0.717115951701999,0.843359079910442,NA,1 -0.114333681805783,0.438450475456193,0.602708112215623,0.349823107244447,0.414686303352937,0.198583260178566,0.261792612727731,0.163943613413721,0.180993436602876,0.669274596963078,0.0332412673160434,NA,1 --0.0575913747610489,0.180707123363391,0.526421991642565,0.82064001262188,0.521263679722324,0.792229591170326,0.514831050531939,0.655353802023456,0.148887025890872,0.567784914281219,0.786496615968645,NA,0.5 --1.01172391605431,0.463357580360025,0.267548341304064,0.839970801956952,0.421053394908085,0.835206566611305,0.946785928681493,0.722578325076029,0.913413252448663,0.466875985031947,0.519509368343279,NA,1 --1.01622844711912,0.636612149421126,0.569439431419596,0.0444083481561393,0.166187049355358,0.890496745705605,0.972773446468636,0.512027263408527,0.453198195435107,0.185353330336511,0.822659314842895,NA,2 -0.879633812737811,0.937646907521412,0.357851214939728,0.754458939423785,0.757599327713251,0.252516763517633,0.253876760136336,0.213992986362427,0.89202951756306,0.274374803295359,0.646439817501232,NA,3 -3.17475900901958,0.715058491099626,0.757979915942997,0.363833147101104,0.339705190155655,0.342578890034929,0.504253518534824,0.836292343214154,0.440996023127809,0.0837488134857267,0.572989572072402,NA,2 --0.0127358680059262,0.91581250471063,0.954323386540636,0.430059639271349,0.762585739139467,0.368888254277408,0.20797983976081,0.47765283472836,0.167148848297074,0.308806973975152,0.314454952254891,NA,3 --0.8544312017705,0.502859967993572,0.151654090732336,0.396629377966747,0.021326761925593,0.205127923749387,0.539602338802069,0.0801719883456826,0.0248979306779802,0.718125011539087,0.321455252822489,NA,2 
--1.82887098666628,0.262003114912659,0.0757940574549139,0.85448711225763,0.564622745616361,0.0720540515612811,0.967587752500549,0.959227720042691,0.638316398952156,0.562585303094238,0.834035288542509,NA,1 -3.2622467572984,0.335096222581342,0.244150118669495,0.183940630406141,0.615293832495809,0.827252987539396,0.267028478439897,0.273504600860178,0.507153154583648,0.388217261061072,0.138010729802772,NA,1 -1.65601016742667,0.838172710733488,0.954755051294342,0.608526331372559,0.836548767983913,0.919543279102072,0.515458332141861,0.598611275665462,0.66656892397441,0.607837118208408,0.493392545264214,NA,3 -0.381933395917132,0.261543265311047,0.95214184653014,0.764482910744846,0.966228008968756,0.493756796000525,0.872538808966056,0.585356997791678,0.242046943400055,0.900154284201562,0.452932744985446,NA,1 -0.061813986449167,0.5902709166985,0.654293139465153,0.18226848798804,0.914126749383286,0.398949036840349,0.773602006956935,0.743641247740015,0.208624595077708,0.4145590539556,0.676244510803372,NA,2 -0.274146489910667,0.560014321003109,0.350247922120616,0.756335157435387,0.416073066880926,0.0586489790584892,0.301071509020403,0.422949285479262,0.210499112727121,0.482112698722631,0.112424957100302,NA,2 --5.75883040392064,0.901148825651035,0.810049108229578,0.878282216377556,0.316235190955922,0.230603245552629,0.250381491146982,0.650639815721661,0.0529408203437924,0.958216429455206,0.889385785441846,NA,3 -0.911749288623184,0.539145407034084,0.596371402265504,0.897788940230384,0.439045100938529,0.16618771571666,0.142886094050482,0.667914129793644,0.535841233795509,0.753273370675743,0.0447190594859421,NA,2 -4.55360390370342,0.945572390919551,0.0924056400544941,0.939797199331224,0.642557488754392,0.0861750931944698,0.632535864366218,0.465771202929318,0.272886190097779,0.772984977345914,0.715489778900519,NA,3 
--0.907822599380484,0.423604313516989,0.944986613234505,0.159030285896733,0.922660069772974,0.964939738158137,0.532024955842644,0.634939571144059,0.39645018754527,0.635082328226417,0.593678319593892,NA,1 --0.139344587678803,0.58962686243467,0.590537373209372,0.14912519371137,0.20451657124795,0.0897868368774652,0.738680734764785,0.374435908161104,0.989756791852415,0.584141218336299,0.653865883359686,NA,2 -0.806150843945478,0.678180137882009,0.973715249914676,0.106831170618534,0.300997588550672,0.750364477513358,0.491450151661411,0.0154496582690626,0.259265987202525,0.558927526231855,0.0112981905695051,NA,2 -1.44095642980371,0.917209735373035,0.341985862003639,0.268955789040774,0.109774171141908,0.239792765583843,0.454389738617465,0.522328067803755,0.60719471052289,0.118838940281421,0.388227568939328,NA,3 --3.25161689162916,0.922700179507956,0.0325096684973687,0.352068450069055,0.232990069780499,0.469084792770445,0.981463437899947,0.11463882913813,0.72538296924904,0.246744823874906,0.321137896971777,NA,3 --0.130522430553986,0.401394749758765,0.256456261267886,0.0354134733788669,0.978563639800996,0.239538576453924,0.814784180140123,0.735997994896024,0.749723166925833,0.62831072579138,0.890577062265947,NA,1 -0.669575467655283,0.00989005784504116,0.435264135012403,0.723876874428242,0.849379999563098,0.0575199017766863,0.957934653852135,0.837231518467888,0.46569788781926,0.891659042565152,0.783426602371037,NA,0.5 -2.54014096182668,0.305441607953981,0.0902002388611436,0.118577564833686,0.385022775735706,0.56807669182308,0.334235888672993,0.538188089616597,0.216295424615964,0.158172710798681,0.00234771450050175,NA,1 --0.193944994135837,0.00190780940465629,0.78280071169138,0.764838046627119,0.291868838947266,0.869910318404436,0.341377722565085,0.648056457517669,0.579640600830317,0.215270293876529,0.61375098163262,NA,0.5 
-0.0208742133525005,0.863172167912126,0.504580818815157,0.374554778914899,0.836817204486579,0.87124543520622,0.670387550722808,0.737122401129454,0.827918927185237,0.599544936791062,0.942843936849385,NA,3 -0.765792314938424,0.566925200633705,0.849535879679024,0.34304199879989,0.847504819976166,0.892173214582726,0.758812437998131,0.214903841959313,0.45636634901166,0.484204821521416,0.207701456500217,NA,2 -0.352804369431417,0.364061176544055,0.670583426486701,0.117415519431233,0.108332704752684,0.316217254614457,0.767299061873928,0.895374773070216,0.667814850108698,0.922258591279387,0.0134780446533114,NA,1 --1.23698048076279,0.332583913113922,0.612572266953066,0.704983988078311,0.768699376378208,0.385655405465513,0.128826092928648,0.116410831455141,0.520903897006065,0.974680992309004,0.806146663846448,NA,1 --2.56441778784605,0.665565764764324,0.605427384609357,0.17659778567031,0.02090925257653,0.704631736734882,0.73883875226602,0.774222782580182,0.268567186314613,0.696540282806382,0.354841801803559,NA,2 -0.232796925229605,0.569475795142353,0.454856015508994,0.939597767777741,0.744496839120984,0.698145196773112,0.467724752146751,0.0180247377138585,0.297438253415748,0.308250970672816,0.523631568299606,NA,2 --0.916045983780381,0.398941761115566,0.661880399100482,0.385367383016273,0.0125482359435409,0.701195284491405,0.652126201428473,0.324592209421098,0.482771011535078,0.943977946648374,0.711980663239956,NA,1 --3.02511296110412,0.601308670360595,0.492553135380149,0.233871235977858,0.712123312056065,0.867592514958233,0.929739307146519,0.856446724850684,0.754480483010411,0.377171362750232,0.526562972692773,NA,2 -0.08116233752899,0.208068270701915,0.806491583352908,0.138934786897153,0.550345825962722,0.273357491008937,0.738135070772842,0.122066766722128,0.67383626755327,0.776206725975499,0.432750965002924,NA,0.5 
--5.27518419045481,0.891456350916997,0.601628924021497,0.281817921902984,0.245754534145817,0.525102650281042,0.350016801385209,0.0317163458094001,0.342192857991904,0.69337944011204,0.535440781386569,NA,3 --0.195996325881689,0.138576569035649,0.187388508580625,0.262100108899176,0.263345927000046,0.493119575316086,0.152671873103827,0.504191888263449,0.913665950298309,0.165937364799902,0.0847142075654119,NA,0.5 -0.315579484938764,0.858387031825259,0.695295455399901,0.467401414411142,0.626965253613889,0.490166572621092,0.617864213185385,0.0769863030873239,0.542200545780361,0.219044300960377,0.0295198371168226,NA,3 -0.0267096393523584,0.266054383246228,0.523457160918042,0.538887435337529,0.597743078833446,0.296432407340035,0.807767712045461,0.84599154163152,0.397196589736268,0.82058397657238,0.093033802928403,NA,1 -1.25261939865097,0.60674773179926,0.706874470924959,0.674645616207272,0.156269168946892,0.858524578856304,0.966854283586144,0.777528064092621,0.547137490706518,0.0508977069985121,0.170303455553949,NA,2 -0.541883133495883,0.299106652615592,0.868335065431893,0.293283038539812,0.337412812747061,0.563003861811012,0.057874902850017,0.206769967218861,0.249845808371902,0.57238865736872,0.431905421428382,NA,1 --0.359359846361593,0.00028673093765974,0.444823984289542,0.267580098472536,0.619911819230765,0.348575697746128,0.445040695136413,0.298632511869073,0.349616114981472,0.0186277241446078,0.727784499060363,NA,0.5 --0.725046148314536,0.738201215397567,0.305303359869868,0.98667918308638,0.875479456968606,0.816655561095104,0.550264997407794,0.78673102799803,0.749039791757241,0.148614945588633,0.50864305277355,NA,2 -2.42662093734972,0.78392618172802,0.779760129982606,0.417675146833062,0.697084854589775,0.642540717031807,0.936560302507132,0.12054528738372,0.307862642919645,0.839434523368254,4.8645306378603e-05,NA,3 
--0.841952793676179,0.248953337082639,0.903661191929132,0.214526182506233,0.640308956848457,0.630132407415658,0.557652691844851,0.294213210465387,0.575431503122672,0.117273130686954,0.394496359396726,NA,0.5 --1.03023205439958,0.724063246278092,0.797944512451068,0.531282962532714,0.761441437294707,0.882055644644424,0.695016658399254,0.902350852033123,0.740311119239777,0.755922955228016,0.0983525444753468,NA,2 --0.912771137967775,0.34925606357865,0.676785026676953,0.588860424933955,0.161323110340163,0.957519877469167,0.398865574738011,0.566670776112005,0.161247945856303,0.247955431696028,0.743449069792405,NA,1 -0.517313406285897,0.119058254174888,0.301976219983771,0.2837041572202,0.0562514583580196,0.732049201382324,0.543040792457759,0.178399549797177,0.637692917603999,0.176733382046223,0.913348912494257,NA,0.5 -0.871011941589735,0.308870760258287,0.508238948881626,0.217849156819284,0.755379884736612,0.308171590790153,0.877481517381966,0.958365337457508,0.897053161635995,0.609838765114546,0.357350649312139,NA,1 -0.723881692134752,0.1084195880685,0.285055675543845,0.494687871308997,0.959149545524269,0.867562021827325,0.437210090924054,0.0436040495987982,0.886218967847526,0.134117919020355,0.147130149416625,NA,0.5 --4.1056823910876,0.836923674680293,0.790600413223729,0.218965121312067,0.631649191724136,0.131414104020223,0.986164937028661,0.1794618004933,0.893512472510338,0.397628468694165,0.977409608894959,NA,3 -0.952558604299245,0.281533166533336,0.653331918641925,0.782094214344397,0.486817280761898,0.712452893378213,0.735165594611317,0.22127672447823,0.17275553708896,0.781067404197529,0.301315511809662,NA,1 -0.434834298048182,0.345692228991538,0.579240264836699,0.952218348626047,0.0265811623539776,0.292065045097843,0.140570497373119,0.594512204639614,0.894585548667237,0.163636759854853,0.791753190103918,NA,1 
--0.617241139637955,0.24287026701495,0.347456730203703,0.352250093361363,0.567248777253553,0.5325484755449,0.362278310814872,0.104282488347962,0.6345406237524,0.298809389350936,0.061041493434459,NA,0.5 --1.10157854604848,0.589005693560466,0.738731898134574,0.411837223684415,0.0231877483893186,0.510611890815198,0.317921287380159,0.142743433825672,0.362699191318825,0.295025371946394,0.297999273752794,NA,2 -0.0779708729765081,0.61409965949133,0.267073255265132,0.644359648926184,0.435340049676597,0.663986770203337,0.853410524781793,0.439644190017134,0.0916713422629982,0.885482318233699,0.328195882029831,NA,2 --0.790086227163868,0.827004836406559,0.588833221467212,0.510606417665258,0.135837320005521,0.959990479284897,0.622058121720329,0.238031924469396,0.0791387532372028,0.875529285753146,0.385706771630794,NA,3 -0.0121011462499208,0.434019140200689,0.76129729533568,0.565473348135129,0.554319101618603,0.193900709273294,0.253057076362893,0.960415460402146,0.64092730381526,0.927356063388288,0.739957847632468,NA,1 -0.528668478450982,0.00797079806216061,0.620697325794026,0.059939440805465,0.933844800340012,0.282361896475777,0.798007321543992,0.109523738268763,0.205230085877702,0.786281060893089,0.626566218677908,NA,0.5 --1.97146007839586,0.426603920059279,0.232865846948698,0.667433485155925,0.0196325895376503,0.752369344001636,0.0936158304102719,0.0917298034764826,0.746732441009954,0.296932371798903,0.478665569564328,NA,1 -0.917133356519574,0.318299685139209,0.434513642685488,0.549943955615163,0.726315794279799,0.817645830567926,0.811867464100942,0.061435961863026,0.746087650069967,0.447872480610386,0.84255820303224,NA,1 --0.329040167848524,0.118530319537967,0.457290950231254,0.645931467181072,0.893649809062481,0.290868752868846,0.963444818975404,0.287232589442283,0.965041985502467,0.503362905001268,0.700591426342726,NA,0.5 
-1.48193500973284,0.625231690006331,0.0553984474390745,0.649881738703698,0.299855779856443,0.496550764422864,0.941366601735353,0.854566443245858,0.833261088933796,0.97451618895866,0.931585984304547,NA,2 --0.173022349447719,0.164724288508296,0.869774119695649,0.643917094916105,0.35930937086232,0.131004156777635,0.581698802066967,0.991578063229099,0.646482613869011,0.973439331632107,0.251832834677771,NA,0.5 --0.798130578450863,0.413501289440319,0.191485705785453,0.66864021285437,0.944299905095249,0.781557427253574,0.426434542285278,0.155345749342814,0.47630034852773,0.368169168476015,0.250597581965849,NA,1 --0.574218404015675,0.242424165830016,0.58770715072751,0.515803768066689,0.381307470612228,0.00103132519870996,0.822859295178205,0.0108886968810111,0.230126534821466,0.557247446617112,0.905343563761562,NA,0.5 -1.20576697217052,0.481467565754429,0.384135053725913,0.250637572957203,0.0487530084792525,0.0894714090973139,0.907941385637969,0.130016501294449,0.108319341670722,0.00571742816828191,0.322211351711303,NA,1 -1.09418799497703,0.428736956324428,0.169671062380075,0.828763456083834,0.351603407878429,0.227558815153316,0.226003634510562,0.346984679810703,0.792112624039873,0.331346849445254,0.432414019247517,NA,1 -1.33295018743162,0.657001668587327,0.637372163590044,0.824361448409036,0.929858485469595,0.478111879900098,0.372787630185485,0.409159137168899,0.687029706314206,0.462824999121949,0.282154636690393,NA,2 -0.767515180489683,0.164000877644867,0.787021941738203,0.480641457485035,0.256527086952701,0.139046657830477,0.260776846436784,0.843071689130738,0.928756222594529,0.0495820019859821,0.623474525287747,NA,0.5 -5.91575131696734,0.968515180051327,0.366838250076398,0.0660362653434277,0.856167251477018,0.179680837085471,0.161809651181102,0.591778725851327,0.48686463618651,0.738822849933058,0.183114260667935,NA,3 
--6.07690196330587,0.830175120383501,0.379230955149978,0.115117009729147,0.729175335494801,0.629532365594059,0.782002086285502,0.933919353410602,0.0596355667803437,0.724859771085903,0.652606250252575,NA,3 --0.313191580590458,0.433103942545131,0.486646147910506,0.0679896532092243,0.0214490725193173,0.0353633773047477,0.744490799959749,0.845955922966823,0.958894333336502,0.156010123435408,0.303289767587557,NA,1 --0.540923578326201,0.460490206023678,0.00578354438766837,0.550281241768971,0.788312985096127,0.210271368967369,0.815538283670321,0.28684966918081,0.275849351193756,0.469581529032439,0.647478682687506,NA,1 --0.673999243552475,0.570351638598368,0.868796721566468,0.776297046570107,0.343447036575526,0.135450173402205,0.647280380362645,0.0270581084769219,0.964931801892817,0.511921106604859,0.636638013413176,NA,2 -0.521317196587708,0.0333191037643701,0.283822370693088,0.192718317965046,0.948545307852328,0.824977292912081,0.370365360984579,0.515110278967768,0.214455303736031,0.224180643446743,0.711288265651092,NA,0.5 -0.464127896273477,0.150690599810332,0.122571487911046,0.888559450628236,0.731974716298282,0.0771893607452512,0.541291358415037,0.552816195879132,0.315732726827264,0.151362558128312,0.962228747783229,NA,0.5 -0.0280326256524612,0.218963563675061,0.667505384190008,0.811947909416631,0.687651550397277,0.0374882640317082,0.161606113892049,0.341489754617214,0.114423418417573,0.711193904979154,0.999109187629074,NA,0.5 -1.72545787431865,0.791984313167632,0.713207671185955,0.329182849498466,0.632728930795565,0.319493535207585,0.219288081862032,0.48391812061891,0.699484226293862,0.955756370443851,0.813189222943038,NA,3 -1.15486734783788,0.566656570881605,0.657263582805172,0.788281690096483,0.778249064926058,0.590916589833796,0.385298176435754,0.389463485917076,0.137170025613159,0.138196402462199,0.308245474472642,NA,2 
--0.457464364994371,0.203198125585914,0.378648053156212,0.808898270130157,0.906652554869652,0.850757014239207,0.988703542388976,0.178554533980787,0.459032043116167,0.799958050018176,0.297242203727365,NA,0.5 -2.0642220969772,0.574996857671067,0.13869056943804,0.148214791202918,0.0713210175745189,0.281843064585701,0.699210103135556,0.611457530176267,0.412616135785356,0.312300917692482,0.753751856740564,NA,2 -1.79325254054518,0.779666494345292,0.699538435321301,0.961586186429486,0.347297596279532,0.0675467550754547,0.325070535531268,0.176488080061972,0.603379048407078,0.136019740719348,0.473342314828187,NA,3 --0.814168249805148,0.179700289387256,0.290472864639014,0.604973735287786,0.684216808062047,0.566976414760575,0.147398902801797,0.29908396420069,0.0406292479019612,0.475435412954539,0.30359836248681,NA,0.5 -1.05601002941817,0.974190507782623,0.59218838927336,0.234243621351197,0.808535494375974,0.187695320462808,0.772726303897798,0.996423880336806,0.286155195906758,0.0625078319571912,0.898713116999716,NA,3 --0.335678118214373,0.589294733013958,0.64874632912688,0.60927852592431,0.912218964891508,0.116446102270856,0.800978301558644,0.552153271622956,0.498927663313225,0.77924073045142,0.604476742446423,NA,2 --0.297220306313378,0.00104556139558554,0.243074154946953,0.28915421012789,0.584380616666749,0.680923626758158,0.190056581050158,0.489183940691873,0.543326803250238,0.462141246302053,0.0547776529565454,NA,0.5 -3.36381201170981,0.842413400299847,0.836705626221374,0.803220853442326,0.52237598830834,0.0978306015022099,0.44979960215278,0.623058009892702,0.0899901306256652,0.659136149799451,0.511611237656325,NA,3 --1.88385207395372,0.52834680210799,0.734877026872709,0.131855968618765,0.703349165152758,0.0538370700087398,0.573549640597776,0.794007008662447,0.812998212641105,0.404450054978952,0.43243193323724,NA,2 
--1.73270176331594,0.779436560813338,0.115024698665366,0.0290894189383835,0.368165988940746,0.811495361849666,0.0700736299622804,0.615472083212808,0.257920050993562,0.869047976098955,0.349914842285216,NA,3 -1.37079153120135,0.570917901350185,0.187404554570094,0.398802037583664,0.846335949376225,0.934954999480397,0.529415836092085,0.910126697039232,0.118137815967202,0.111092046368867,0.248694343725219,NA,2 -0.571363043562507,0.331030112225562,0.450528397923335,0.336339824367315,0.16454178863205,0.430747980251908,0.505673100007698,0.328476364724338,0.21718189259991,0.586265194229782,0.447927470318973,NA,1 --1.79533836540329,0.976926734903827,0.0961627899669111,0.60899573401548,0.599624714581296,0.713024838129058,0.36587701109238,0.517652825685218,0.313225879799575,0.691182613139972,0.946866175159812,NA,3 -2.76102079506608,0.760612237034366,0.802744230488315,0.707125860266387,0.710110249230638,0.0485071069560945,0.337238987907767,0.603377134539187,0.396445233840495,0.336875891312957,0.0135681433603168,NA,3 -1.03587118074943,0.546186395455152,0.0895447288639843,0.102424295851961,0.124316787347198,0.398161188932136,0.0595820427406579,0.0317856874316931,0.711451485287398,0.184543622424826,0.570017953403294,NA,2 -2.64092013401424,0.813234215602279,0.346893396694213,0.999183400999755,0.163131871260703,0.109145125374198,0.742135232314467,0.926391132874414,0.354281723732129,0.518946023192257,0.143736189464107,NA,3 --2.14161378373943,0.623024221044034,0.987261440372095,0.240460366243497,0.0876346873119473,0.137773034861311,0.00308776297606528,0.764373373007402,0.797460315050557,0.443219825392589,0.516431802418083,NA,2 --0.193411209224147,0.0207313429564238,0.243699158774689,0.459949440322816,0.262547085061669,0.695130342617631,0.946504537714645,0.726618548389524,0.864191201748326,0.815897560212761,0.0842976626008749,NA,0.5 
-1.04927342445141,0.472177557414398,0.0301995284389704,0.293144194176421,0.790906876791269,0.438358291517943,0.794757833238691,0.85613635671325,0.811791674233973,0.112252983730286,0.169568839715794,NA,1 -0.348988455997724,0.262080669403076,0.927174943732098,0.142983013764024,0.640390169108286,0.366609850898385,0.529223524965346,0.57928417250514,0.206402273848653,0.268847646424547,0.165483654243872,NA,1 -0.696440990671028,0.983129711588845,0.654424325563014,0.83093448006548,0.817050429061055,0.692395427962765,0.485632661730051,0.661732462001964,0.0443739434704185,0.631122295744717,0.821382548194379,NA,3 --1.50946435977094,0.524429301731288,0.394003928638995,0.612721437122673,0.641075712163001,0.8340331655927,0.81034871423617,0.378355852561072,0.118066052906215,0.0152956454548985,0.0607850980013609,NA,2 --3.60722631045085,0.973684054333717,0.862753849243745,0.470457507064566,0.516427976312116,0.571779370773584,0.00344765046611428,0.557507778517902,0.863539802376181,0.618254639208317,0.0438256817869842,NA,3 -0.640760110283842,0.3505244248081,0.388575355056673,0.525333723518997,0.90868836059235,0.656305950833485,0.87575338082388,0.145426131319255,0.142402743920684,0.722210387466475,0.646311188582331,NA,1 -1.0927036128643,0.0793833178468049,0.0104481617454439,0.93192859296687,0.779025187250227,0.358062380226329,0.498329946072772,0.557266015792266,0.794099885271862,0.103802968282253,0.489988661836833,NA,0.5 --0.960320030022167,0.231456351932138,0.585931064095348,0.872712701093405,0.684092679759488,0.747057758970186,0.33389916899614,0.693479184061289,0.234955767868087,0.0913671234156936,0.484689482487738,NA,0.5 -5.55515128084404,0.961973294848576,0.598567689536139,0.797411441337317,0.749971851240844,0.626418237108737,0.475173852639273,0.763921497855335,0.795191584387794,0.464224591152743,0.356469079386443,NA,3 
-0.222792689673181,0.216883212793618,0.504161356948316,0.685418124077842,0.138950776075944,0.609323340700939,0.0815524500794709,0.887275154469535,0.673021834110841,0.229804648784921,0.545686026103795,NA,0.5 -0.493718034822244,0.11028397642076,0.503047540783882,0.179511502385139,0.763330884743482,0.77220833580941,0.934788197511807,0.494932206347585,0.731283217668533,0.855659510940313,0.0117787455674261,NA,0.5 -0.461772257413767,0.29381926660426,0.762450747890398,0.234901242889464,0.5249476833269,0.629562719725072,0.200397627661005,0.684597495477647,0.356038145022467,0.802739176666364,0.600334402406588,NA,1 --0.126129491991506,0.463731594849378,0.167885933071375,0.540383089799434,0.282209430588409,0.0468776472844183,0.37753825658001,0.563547776313499,0.0574275727849454,0.122192964889109,0.144121702294797,NA,1 -1.61069779366059,0.353434937307611,0.44708115211688,0.562454687664285,0.681164643960074,0.344466916518286,0.389643874252215,0.47427371609956,0.420238852966577,0.246512484736741,0.0113735438790172,NA,1 -0.455525567857585,0.656678552739322,0.520737201906741,0.440892208134755,0.383611467434093,0.853515934199095,0.312850524671376,0.188586953794584,0.611132161458954,0.416548165027052,0.966824663570151,NA,2 -0.752414003343478,0.368973962496966,0.144383963663131,0.298857402056456,0.603884643642232,0.948024961864576,0.0727454784791917,0.0918680941686034,0.0495259300805628,0.778334744274616,0.237199151189998,NA,1 --1.8071186112866,0.330021998845041,0.0695072312373668,0.405746232718229,0.476144076324999,0.162464414490387,0.143160074716434,0.2627956320066,0.371168155921623,0.219870826462284,0.593052309006453,NA,1 -0.338110168780804,0.0405488722026348,0.971289380686358,0.55798068922013,0.217713665682822,0.626417078776285,0.498581673717126,0.377851612633094,0.218278390355408,0.651240651495755,0.95426289155148,NA,0.5 
--3.08507748542801,0.905532934004441,0.329712072620168,0.140193846542388,0.36389043321833,0.374093280639499,0.146446174709126,0.438101051608101,0.795116507913917,0.983287858776748,0.488525949651375,NA,3 --1.72372688590346,0.696682563750073,0.349387434776872,0.505626027705148,0.826131088426337,0.336876499466598,0.579624563921243,0.374704758403823,0.0984764297027141,0.158775766147301,0.86363190994598,NA,2 -0.613683233774074,0.0651458308566362,0.581464517861605,0.907430648803711,0.798097816295922,0.125102420337498,0.432380372891203,0.18545747292228,0.108076399890706,0.328964518383145,0.317265540128574,NA,0.5 --0.133225601541877,0.859262556303293,0.671681129839271,0.68166439072229,0.624605974880978,0.0363332864362746,0.559169779764488,0.0110217127948999,0.197365242987871,0.235231575090438,0.954349690815434,NA,3 -0.174705914799019,0.0484868276398629,0.72244808007963,0.412865192163736,0.177109800977632,0.49629073496908,0.43845453299582,0.883392751449719,0.354612664552405,0.244716308778152,0.773190965875983,NA,0.5 --2.11705848424155,0.688459891825914,0.221930157393217,0.0272207804955542,0.0866714252624661,0.0403791540302336,0.440649472177029,0.0105264887679368,0.100412974366918,0.790013117250055,0.270158890169114,NA,2 -1.1060152777824,0.375678987475112,0.924136486602947,0.474893233040348,0.824483483564109,0.0317611857317388,0.0541152430232614,0.202915545087308,0.701533750165254,0.888495072023943,0.528801799751818,NA,1 --1.40484738126697,0.914472201606259,0.389121263055131,0.115709668258205,0.70249339309521,0.88396455720067,0.0478117982856929,0.939067059895024,0.574597767787054,0.756754127331078,0.176177548244596,NA,3 --1.22766612067787,0.357021145289764,0.348640100797638,0.841263998765498,0.261813498334959,0.820658108917996,0.24100551684387,0.458479619817808,0.121229101205245,0.780505152186379,0.284189038909972,NA,1 
--0.137006680090441,0.0419280277565122,0.456591938855127,0.769418158335611,0.3475969475694,0.325503826374188,0.268078948371112,0.801353276707232,0.0109000348020345,0.909261407796293,0.970131604699418,NA,0.5 --0.0472805089628641,0.458404157776386,0.509212339995429,0.472755545051768,0.89819523692131,0.00172450905665755,0.195178902009502,0.491253996500745,0.602235041558743,0.000361819053068757,0.810925672529265,NA,1 --6.69280752542676,0.794113770360127,0.283597531029955,0.415720561984926,0.386682669632137,0.543149482225999,0.528418547241017,0.59079469111748,0.82328638038598,0.7218906681519,0.762335718609393,NA,3 --0.791993317715144,0.94858734565787,0.862443518359214,0.336455043870956,0.178935058414936,0.876441803527996,0.886401900788769,0.0696042117197067,0.201479215407744,0.579562027938664,0.916529270121828,NA,3 --3.5798937723398,0.518584945704788,0.341426242142916,0.980938139371574,0.025910884141922,0.703400827012956,0.611595463240519,0.540800383547321,0.108336265431717,0.802920109592378,0.717154868878424,NA,2 -1.34179692842582,0.256257413187996,0.137025708565488,0.189194887410849,0.356893142918125,0.456141719361767,0.802874229149893,0.286960861412808,0.539246008498594,0.623640240402892,0.202499657869339,NA,1 -0.574692809911355,0.367721185088158,0.189689038787037,0.650288314092904,0.916537459241226,0.189138949615881,0.176210168516263,0.626079858280718,0.105365733848885,0.0375174721702933,0.556417517596856,NA,1 -0.291989692067655,0.374401310924441,0.173651255434379,0.666860758094117,0.863264636369422,0.927415754180402,0.457678058184683,0.499020219780505,0.751850564265624,0.987137568416074,0.329214916797355,NA,1 --1.0948854251991,0.480366722913459,0.382694880710915,0.93061161483638,0.545312937349081,0.46468991599977,0.467403194168583,0.945679706521332,0.918626469327137,0.6478590965271,0.0622258309740573,NA,1 
-0.549466325611295,0.832923032343388,0.586885153781623,0.856621402082965,0.0974842766299844,0.193075296701863,0.797113571548834,0.331570270936936,0.681197754340246,0.456899536773562,0.216078756842762,NA,3 --1.18444309638044,0.987224767683074,0.0441201773937792,0.150253226514906,0.924849382368848,0.243760817451403,0.601127431960776,0.941083330428228,0.0376720677595586,0.590444662142545,0.594265560386702,NA,3 --2.6287892127699,0.511187988799065,0.17760677007027,0.0900412548799068,0.707964851288125,0.772246306762099,0.0507865459658206,0.880909067578614,0.565439280355349,0.325563751393929,0.625947890337557,NA,2 From 8080c8cd94df53a3b2a8d456f809cbf35f92700f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:02:11 -0400 Subject: [PATCH 02/64] Add more C++ and VSCode infrastructure for RFC 0004 --- .devcontainer/Dockerfile | 2 +- .devcontainer/devcontainer.json | 39 ++++-- .gitignore | 6 +- .vscode/extensions.json | 8 ++ .vscode/launch.json | 69 ++++++++++ .vscode/tasks.json | 69 ++++++++++ CMakeLists.txt | 61 ++++++--- CMakePresets.json | 106 +++++++++++++++ debug/bart_debug.cpp | 152 +++++++++++++++++++++ debug/bcf_debug.cpp | 232 ++++++++++++++++++++++++++++++++ 10 files changed, 708 insertions(+), 36 deletions(-) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/launch.json create mode 100644 .vscode/tasks.json create mode 100644 CMakePresets.json create mode 100644 debug/bart_debug.cpp create mode 100644 debug/bcf_debug.cpp diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 8af2f042..33602460 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -5,7 +5,7 @@ RUN apt-get update -y && \ valgrind && \ rm -rf /var/lib/apt/lists/* -ARG REINSTALL_CMAKE_VERSION_FROM_SOURCE="3.22.2" +ARG REINSTALL_CMAKE_VERSION_FROM_SOURCE="3.29.3" # Optionally install the cmake for vcpkg COPY ./reinstall-cmake.sh /tmp/ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 
b519e257..f496e8ba 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,23 +1,34 @@ // For format details, see https://aka.ms/devcontainer.json. For config options, see the // README at: https://github.com/devcontainers/templates/tree/main/src/cpp { - "name": "C++", + "name": "stochtree C++ Dev", "build": { "dockerfile": "Dockerfile" - } + }, - // Features to add to the dev container. More info: https://containers.dev/features. - // "features": {}, + "features": { + "ghcr.io/devcontainers/features/git:1": {}, + "ghcr.io/devcontainers/features/github-cli:1": {} + }, - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], + "customizations": { + "vscode": { + "extensions": [ + "ms-vscode.cmake-tools", + "llvm-vs-code-extensions.vscode-clangd", + "vadimcn.vscode-lldb" + ], + "settings": { + "clangd.path": "/usr/bin/clangd", + "clangd.arguments": ["--compile-commands-dir=${workspaceFolder}/build"], + "cmake.configureOnOpen": true, + "cmake.defaultConfigurePreset": "dev-quick", + "cmake.defaultBuildPreset": "dev-quick" + } + } + }, - // Use 'postCreateCommand' to run commands after the container is created. - // "postCreateCommand": "gcc -v", - - // Configure tool-specific properties. - // "customizations": {}, - - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "root" + // Configure dev build (no test download) on container creation. + // Switch to the "dev" preset and re-run cmake when you need GoogleTest. 
+ "postCreateCommand": "cmake --preset dev-quick && cmake --build --preset dev-quick" } diff --git a/.gitignore b/.gitignore index 0b5146bc..6acc4362 100644 --- a/.gitignore +++ b/.gitignore @@ -4,9 +4,13 @@ *.DS_Store lib/ build/ -.vscode/ +.vscode/positron/ xcode/ *.json +!.vscode/extensions.json +!.vscode/tasks.json +!.vscode/launch.json +!CMakePresets.json !test/R/testthat/fixtures/*.json !test/python/fixtures/*.json .vs/ diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..33e5b7db --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,8 @@ +{ + "recommendations": [ + "ms-vscode.cmake-tools", + "vadimcn.vscode-lldb", + "llvm-vs-code-extensions.vscode-clangd", + "Posit.air-vscode" + ] +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..4b2abca8 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,69 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "bart_debug (macOS)", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/bart_debug", + "args": ["0"], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: Build (dev-quick)" + }, + { + "name": "bcf_debug (macOS)", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/bcf_debug", + "args": ["0"], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: Build (dev-quick)" + }, + { + "name": "teststochtree (macOS)", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/teststochtree", + "args": [], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: Build (dev)" + }, + { + "name": "bart_debug (Linux/Container)", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/build/bart_debug", + "args": ["0"], + "cwd": "${workspaceFolder}", + "MIMode": "gdb", + "preLaunchTask": "CMake: Build (dev-quick)" + }, + { + "name": "bcf_debug (Linux/Container)", + "type": "cppdbg", + "request": "launch", + "program": 
"${workspaceFolder}/build/bcf_debug", + "args": ["0"], + "cwd": "${workspaceFolder}", + "MIMode": "gdb", + "preLaunchTask": "CMake: Build (dev-quick)" + }, + { + "name": "teststochtree (Linux/Container)", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/build/teststochtree", + "args": [], + "cwd": "${workspaceFolder}", + "MIMode": "gdb", + "preLaunchTask": "CMake: Build (dev)" + }, + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 00000000..585737d9 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,69 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "CMake: Configure (dev)", + "type": "shell", + "command": "cmake --preset dev", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CMake: Configure (dev-quick)", + "type": "shell", + "command": "cmake --preset dev-quick", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CMake: Build (dev)", + "type": "shell", + "command": "cmake --build --preset dev", + "group": { "kind": "build", "isDefault": true }, + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CMake: Build (dev-quick)", + "type": "shell", + "command": "cmake --build --preset dev-quick", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CMake: Build (release)", + "type": "shell", + "command": "cmake --build --preset release", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CMake: Build (sanitizer)", + "type": "shell", + "command": "cmake --build --preset 
sanitizer", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CTest: Run All", + "type": "shell", + "command": "ctest --preset dev", + "group": { "kind": "test", "isDefault": true }, + "problemMatcher": [], + "presentation": { "reveal": "always", "panel": "shared" } + }, + { + "label": "CTest: Run All (sanitizer)", + "type": "shell", + "command": "ctest --preset sanitizer", + "group": "test", + "problemMatcher": [], + "presentation": { "reveal": "always", "panel": "shared" } + } + ] +} diff --git a/CMakeLists.txt b/CMakeLists.txt index b6471bd2..0f8a70ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,13 +130,12 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build) # Aggregate the source files underpinning the implementation in the C++ library file( - GLOB + GLOB SOURCES src/container.cpp src/cutpoint_candidates.cpp src/data.cpp src/io.cpp - src/json11.cpp src/leaf_model.cpp src/ordinal_sampler.cpp src/partition_tracker.cpp @@ -177,18 +176,29 @@ endif() # Build C++ test program if(BUILD_TEST) - # Check if user specified a local clone of the GoogleTest repo, use Github repo if not - if (NOT DEFINED GOOGLETEST_GIT_REPO) - set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git) - endif() - - # Fetch and install GoogleTest dependency include(FetchContent) - FetchContent_Declare( - googletest - GIT_REPOSITORY ${GOOGLETEST_GIT_REPO} - GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0 - ) + + set(GTEST_SUBMODULE_DIR "${PROJECT_SOURCE_DIR}/deps/googletest") + if(EXISTS "${GTEST_SUBMODULE_DIR}/CMakeLists.txt") + # Use the local submodule — no network required. 
+ # Initialize with: git submodule update --init deps/googletest + message(STATUS "GoogleTest: using local submodule at ${GTEST_SUBMODULE_DIR}") + FetchContent_Declare( + googletest + SOURCE_DIR "${GTEST_SUBMODULE_DIR}" + ) + else() + # Fall back to GitHub fetch (CI, shallow clones, or submodule not initialized). + if (NOT DEFINED GOOGLETEST_GIT_REPO) + set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git) + endif() + message(STATUS "GoogleTest: fetching from ${GOOGLETEST_GIT_REPO}") + FetchContent_Declare( + googletest + GIT_REPOSITORY ${GOOGLETEST_GIT_REPO} + GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0 + ) + endif() # For Windows: Prevent overriding the parent project's compiler/linker settings set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) @@ -209,17 +219,28 @@ if(BUILD_TEST) gtest_discover_tests(teststochtree) endif() -# Standalone C++ Program for Debugging +# Standalone C++ Programs for Debugging if(BUILD_DEBUG_TARGETS) - # Build test suite - add_executable(debugstochtree debug/api_debug.cpp) set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/debug) + + # BART debug driver + add_executable(bart_debug debug/bart_debug.cpp) + if(USE_OPENMP) + target_include_directories(bart_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR}) + target_link_libraries(bart_debug PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY}) + else() + target_include_directories(bart_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR}) + target_link_libraries(bart_debug PRIVATE stochtree_objs) + endif() + + # BCF debug driver + add_executable(bcf_debug debug/bcf_debug.cpp) if(USE_OPENMP) - target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} 
${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR}) - target_link_libraries(debugstochtree PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY}) + target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR}) + target_link_libraries(bcf_debug PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY}) else() - target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR}) - target_link_libraries(debugstochtree PRIVATE stochtree_objs) + target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR}) + target_link_libraries(bcf_debug PRIVATE stochtree_objs) endif() endif() diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 00000000..42c7644c --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,106 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "dev", + "displayName": "Dev (Debug + Tests)", + "description": "Debug build with GoogleTest and debug targets — primary development preset", + "binaryDir": "${sourceDir}/build", + "cacheVariables": { + "USE_DEBUG": "ON", + "BUILD_DEBUG_TARGETS": "ON", + "BUILD_TEST": "ON", + "BUILD_PYTHON": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "FETCHCONTENT_UPDATES_DISCONNECTED": "ON" + } + }, + { + "name": "dev-quick", + "displayName": "Dev (Debug, no tests)", + "description": "Debug build with debug targets, skips GoogleTest download — faster configure", + "binaryDir": "${sourceDir}/build", + "cacheVariables": { + "USE_DEBUG": "ON", + "BUILD_DEBUG_TARGETS": "ON", + "BUILD_TEST": "OFF", + 
"BUILD_PYTHON": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" + } + }, + { + "name": "release", + "displayName": "Release", + "description": "Optimized release build", + "binaryDir": "${sourceDir}/build-release", + "cacheVariables": { + "USE_DEBUG": "OFF", + "BUILD_DEBUG_TARGETS": "OFF", + "BUILD_TEST": "OFF", + "BUILD_PYTHON": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "OFF" + } + }, + { + "name": "sanitizer", + "displayName": "Sanitizer (ASAN + UBSAN)", + "description": "Debug build with address and undefined-behavior sanitizers", + "binaryDir": "${sourceDir}/build-sanitizer", + "cacheVariables": { + "USE_DEBUG": "ON", + "USE_SANITIZER": "ON", + "BUILD_DEBUG_TARGETS": "ON", + "BUILD_TEST": "ON", + "BUILD_PYTHON": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "FETCHCONTENT_UPDATES_DISCONNECTED": "ON" + } + } + ], + "buildPresets": [ + { + "name": "dev", + "displayName": "Dev (Debug + Tests)", + "configurePreset": "dev", + "jobs": 0 + }, + { + "name": "dev-quick", + "displayName": "Dev (Debug, no tests)", + "configurePreset": "dev-quick", + "jobs": 0 + }, + { + "name": "release", + "displayName": "Release", + "configurePreset": "release", + "jobs": 0 + }, + { + "name": "sanitizer", + "displayName": "Sanitizer", + "configurePreset": "sanitizer", + "jobs": 0 + } + ], + "testPresets": [ + { + "name": "dev", + "displayName": "Dev tests", + "configurePreset": "dev", + "output": { + "outputOnFailure": true, + "verbosity": "default" + } + }, + { + "name": "sanitizer", + "displayName": "Sanitizer tests", + "configurePreset": "sanitizer", + "output": { + "outputOnFailure": true, + "verbosity": "default" + } + } + ] +} diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp new file mode 100644 index 00000000..eeba33ae --- /dev/null +++ b/debug/bart_debug.cpp @@ -0,0 +1,152 @@ +/* + * BART debug driver. The first CLI argument selects the scenario (default: 0). 
+ * + * Usage: bart_debug [scenario] + * 0 Homoskedastic constant-leaf BART + * DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + eps, eps ~ N(0,1) + * + * Add scenarios here as the BARTSampler API develops (heteroskedastic, + * random effects, multivariate leaf, etc.). + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +static constexpr double kPi = 3.14159265358979323846; + +// ---- Data ------------------------------------------------------------ + +struct Dataset { + Eigen::Matrix X; + Eigen::VectorXd y; +}; + +// DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) +Dataset generate_data(int n, int p, std::mt19937& rng) { + std::uniform_real_distribution unif(0.0, 1.0); + std::normal_distribution normal(0.0, 1.0); + Dataset d; + d.X.resize(n, p); + d.y.resize(n); + for (int i = 0; i < n; i++) + for (int j = 0; j < p; j++) + d.X(i, j) = unif(rng); + for (int i = 0; i < n; i++) + d.y(i) = std::sin(2.0 * kPi * d.X(i, 0)) + + 0.5 * d.X(i, 1) + - 1.5 * d.X(i, 2) + + normal(rng); + return d; +} + +// ---- Scenario 0: homoskedastic constant-leaf BART ------------------- + +void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { + constexpr int num_threads = 1; + constexpr int cutpoint_grid_size = 100; + std::mt19937 rng(42); + + Dataset data = generate_data(n, p, rng); + double y_bar = data.y.mean(); + Eigen::VectorXd resid_vec = data.y.array() - y_bar; + + StochTree::ForestDataset dataset; + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + StochTree::ColumnVector residual(resid_vec.data(), n); + + std::vector feature_types(p, StochTree::FeatureType::kNumeric); + std::vector var_weights(p, 1.0 / p); + std::vector sweep_indices; + + StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); + StochTree::ForestContainer forest_samples(num_trees, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); + StochTree::TreeEnsemble 
active_forest(num_trees, 1, true, false); + StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, num_trees, n); + + double leaf_scale = 1.0 / num_trees; + StochTree::GaussianConstantLeafModel leaf_model(leaf_scale); + + double global_variance = 1.0; + constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior + StochTree::GlobalHomoskedasticVarianceModel var_model; + + // GFR warmup — no samples stored + std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; + bool pre_initialized = false; + for (int i = 0; i < num_gfr; i++) { + StochTree::GFRSampleOneIter< + StochTree::GaussianConstantLeafModel, + StochTree::GaussianConstantSuffStat>( + active_forest, tracker, forest_samples, leaf_model, + dataset, residual, tree_prior, rng, + var_weights, sweep_indices, global_variance, feature_types, + cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, + /*backfitting=*/true, /*num_features_subsample=*/-1, num_threads); + global_variance = var_model.SampleVarianceParameter( + residual.GetData(), a_sigma, b_sigma, rng); + pre_initialized = true; + } + + // MCMC — store samples + std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; + for (int i = 0; i < num_mcmc; i++) { + StochTree::MCMCSampleOneIter< + StochTree::GaussianConstantLeafModel, + StochTree::GaussianConstantSuffStat>( + active_forest, tracker, forest_samples, leaf_model, + dataset, residual, tree_prior, rng, + var_weights, sweep_indices, global_variance, + /*keep_forest=*/true, /*pre_initialized=*/true, + /*backfitting=*/true, num_threads); + global_variance = var_model.SampleVarianceParameter( + residual.GetData(), a_sigma, b_sigma, rng); + } + + // Posterior predictions: column-major, element [j*n + i] = sample j, obs i + std::vector preds = forest_samples.Predict(dataset); + double rmse_sum = 0.0; + for (int i = 0; i < n; i++) { + double mu_hat = y_bar; + for (int j = 0; j < num_mcmc; j++) + mu_hat += preds[static_cast(j * n + i)] / num_mcmc; + 
double err = mu_hat - data.y(i); + rmse_sum += err * err; + } + + std::cout << "\nScenario 0 (HomoskedasticBART):\n" + << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" + << " sigma (last sample): " << std::sqrt(global_variance) << "\n" + << " sigma (truth): 1.0\n"; +} + +// ---- Main ----------------------------------------------------------- + +int main(int argc, char** argv) { + int scenario = 0; + if (argc > 1) scenario = std::stoi(argv[1]); + + constexpr int n = 200, p = 5, num_trees = 200, num_gfr = 20, num_mcmc = 100; + + switch (scenario) { + case 0: + run_scenario_0(n, p, num_trees, num_gfr, num_mcmc); + break; + default: + std::cerr << "Unknown scenario " << scenario + << ". Available scenarios: 0 (HomoskedasticBART)\n"; + return 1; + } + return 0; +} diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp new file mode 100644 index 00000000..6911097d --- /dev/null +++ b/debug/bcf_debug.cpp @@ -0,0 +1,232 @@ +/* + * BCF debug driver. The first CLI argument selects the scenario (default: 0). + * + * Usage: bcf_debug [scenario] + * 0 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) + * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 + * tau(x) = 1 + x3 + * z ~ Bernoulli(0.5) + * y = mu(x) + tau(x)*z + N(0, 0.5^2) + * + * Add scenarios here as the BCFSampler API develops (heteroskedastic, + * random effects, propensity weighting, etc.). + * + * Algorithm overview + * ------------------ + * Both forests share a single ColumnVector residual. Alternating GFR/MCMC + * steps for mu and tau each run backfitting, so the residual after each + * step correctly reflects the other forest's current contribution: + * + * After mu step: residual ≈ y - y_bar - mu_hat + * After tau step: residual ≈ y - y_bar - mu_hat - tau_hat*z + * + * The tau forest uses z as a univariate basis (AddBasis), so its prediction + * for observation i is tau_leaf(i) * z(i), and backfitting is z-aware. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +static constexpr double kPi = 3.14159265358979323846; + +// ---- Data ------------------------------------------------------------ + +struct BCFDataset { + Eigen::Matrix X; + Eigen::VectorXd y; + Eigen::VectorXd z; + Eigen::VectorXd mu_true; + Eigen::VectorXd tau_true; +}; + +BCFDataset generate_data(int n, int p, std::mt19937& rng) { + std::uniform_real_distribution unif(0.0, 1.0); + std::normal_distribution normal(0.0, 1.0); + std::bernoulli_distribution bern(0.5); + + BCFDataset d; + d.X.resize(n, p); + d.y.resize(n); + d.z.resize(n); + d.mu_true.resize(n); + d.tau_true.resize(n); + + for (int i = 0; i < n; i++) + for (int j = 0; j < p; j++) + d.X(i, j) = unif(rng); + + for (int i = 0; i < n; i++) { + d.z(i) = bern(rng) ? 1.0 : 0.0; + d.mu_true(i) = 2.0 * std::sin(kPi * d.X(i, 0)) + 0.5 * d.X(i, 1); + d.tau_true(i) = 1.0 + d.X(i, 2); + d.y(i) = d.mu_true(i) + d.tau_true(i) * d.z(i) + 0.5 * normal(rng); + } + return d; +} + +// ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- + +void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { + constexpr int num_threads = 1; + constexpr int cutpoint_grid_size = 100; + std::mt19937 rng(42); + + BCFDataset data = generate_data(n, p, rng); + double y_bar = data.y.mean(); + Eigen::VectorXd resid_vec = data.y.array() - y_bar; + + // Mu dataset: X covariates only + StochTree::ForestDataset dataset_mu; + dataset_mu.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + + // Tau dataset: X covariates + Z as univariate basis + StochTree::ForestDataset dataset_tau; + dataset_tau.AddCovariates(data.X.data(), n, p, true); + dataset_tau.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); + + // Shared residual + StochTree::ColumnVector residual(resid_vec.data(), n); + + std::vector feature_types(p, StochTree::FeatureType::kNumeric); + 
std::vector var_weights(p, 1.0 / p); + std::vector sweep_indices; + + StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); + + // Mu forest: constant-leaf + StochTree::ForestContainer mu_samples(num_trees, 1, /*leaf_constant=*/true, /*exponentiated=*/false); + StochTree::TreeEnsemble mu_forest(num_trees, 1, true, false); + StochTree::ForestTracker mu_tracker(dataset_mu.GetCovariates(), feature_types, num_trees, n); + double mu_leaf_scale = 1.0 / num_trees; + StochTree::GaussianConstantLeafModel mu_leaf_model(mu_leaf_scale); + + // Tau forest: univariate regression leaf (prediction = leaf_param * z) + StochTree::ForestContainer tau_samples(num_trees, 1, /*leaf_constant=*/false, /*exponentiated=*/false); + StochTree::TreeEnsemble tau_forest(num_trees, 1, false, false); + StochTree::ForestTracker tau_tracker(dataset_tau.GetCovariates(), feature_types, num_trees, n); + double tau_leaf_scale = 1.0 / num_trees; + StochTree::GaussianUnivariateRegressionLeafModel tau_leaf_model(tau_leaf_scale); + + double global_variance = 1.0; + constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior + StochTree::GlobalHomoskedasticVarianceModel var_model; + + // GFR warmup — no samples stored + std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; + bool pre_mu = false, pre_tau = false; + for (int i = 0; i < num_gfr; i++) { + StochTree::GFRSampleOneIter< + StochTree::GaussianConstantLeafModel, + StochTree::GaussianConstantSuffStat>( + mu_forest, mu_tracker, mu_samples, mu_leaf_model, + dataset_mu, residual, tree_prior, rng, + var_weights, sweep_indices, global_variance, feature_types, + cutpoint_grid_size, /*keep_forest=*/false, pre_mu, + /*backfitting=*/true, /*num_features_subsample=*/-1, num_threads); + pre_mu = true; + + StochTree::GFRSampleOneIter< + StochTree::GaussianUnivariateRegressionLeafModel, + StochTree::GaussianUnivariateRegressionSuffStat>( + tau_forest, tau_tracker, tau_samples, tau_leaf_model, + dataset_tau, residual, 
tree_prior, rng, + var_weights, sweep_indices, global_variance, feature_types, + cutpoint_grid_size, false, pre_tau, + true, -1, num_threads); + pre_tau = true; + + global_variance = var_model.SampleVarianceParameter( + residual.GetData(), a_sigma, b_sigma, rng); + } + + // MCMC — store samples + std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; + for (int i = 0; i < num_mcmc; i++) { + StochTree::MCMCSampleOneIter< + StochTree::GaussianConstantLeafModel, + StochTree::GaussianConstantSuffStat>( + mu_forest, mu_tracker, mu_samples, mu_leaf_model, + dataset_mu, residual, tree_prior, rng, + var_weights, sweep_indices, global_variance, + /*keep_forest=*/true, /*pre_initialized=*/true, + /*backfitting=*/true, num_threads); + + StochTree::MCMCSampleOneIter< + StochTree::GaussianUnivariateRegressionLeafModel, + StochTree::GaussianUnivariateRegressionSuffStat>( + tau_forest, tau_tracker, tau_samples, tau_leaf_model, + dataset_tau, residual, tree_prior, rng, + var_weights, sweep_indices, global_variance, + true, true, true, num_threads); + + global_variance = var_model.SampleVarianceParameter( + residual.GetData(), a_sigma, b_sigma, rng); + } + + // Posterior predictions + // mu_preds[j*n + i] = mu_hat for sample j, obs i (column-major) + // tau_preds[j*n + i] = tau_hat(i)*z(i) (since basis is z) + std::vector mu_preds = mu_samples.Predict(dataset_mu); + std::vector tau_preds = tau_samples.Predict(dataset_tau); + + double mu_rmse_sum = 0.0; + double tau_rmse_sum = 0.0; + int n_treated = 0; + + for (int i = 0; i < n; i++) { + double mu_hat = y_bar; + for (int j = 0; j < num_mcmc; j++) + mu_hat += mu_preds[static_cast(j * n + i)] / num_mcmc; + double mu_err = mu_hat - data.mu_true(i); + mu_rmse_sum += mu_err * mu_err; + + // For z=1: tau_preds = tau_hat * 1 = tau_hat, so we can evaluate CATE + if (data.z(i) > 0.5) { + double tau_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + tau_hat += tau_preds[static_cast(j * n + i)] / num_mcmc; + double tau_err = tau_hat 
- data.tau_true(i); + tau_rmse_sum += tau_err * tau_err; + n_treated++; + } + } + + std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" + << " mu RMSE: " << std::sqrt(mu_rmse_sum / n) << "\n" + << " tau RMSE (treated): " + << (n_treated > 0 ? std::sqrt(tau_rmse_sum / n_treated) : 0.0) << "\n" + << " sigma (last sample): " << std::sqrt(global_variance) << "\n" + << " sigma (truth): 0.5\n"; +} + +// ---- Main ----------------------------------------------------------- + +int main(int argc, char** argv) { + int scenario = 0; + if (argc > 1) scenario = std::stoi(argv[1]); + + constexpr int n = 200, p = 5, num_trees = 200, num_gfr = 20, num_mcmc = 100; + + switch (scenario) { + case 0: + run_scenario_0(n, p, num_trees, num_gfr, num_mcmc); + break; + default: + std::cerr << "Unknown scenario " << scenario + << ". Available scenarios: 0 (BasicBCF)\n"; + return 1; + } + return 0; +} From c3296089e9c7b955c50027ea1c7474153015f3a6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:06:17 -0400 Subject: [PATCH 03/64] Added clang formatting and linting settings --- .clang-format | 5 +++++ .clang-tidy | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 .clang-format create mode 100644 .clang-tidy diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..b6a55284 --- /dev/null +++ b/.clang-format @@ -0,0 +1,5 @@ +BasedOnStyle: Google +IndentWidth: 2 +AccessModifierOffset: -1 +ColumnLimit: 0 +SortIncludes: false diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..1f3a66ea --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,32 @@ +--- +# Conservative clang-tidy config for stochtree. +# Focuses on real bugs and performance issues; style is handled by clang-format. 
+# To run manually: clang-tidy -- (or via clangd in the editor) + +Checks: > + clang-analyzer-*, + bugprone-branch-clone, + bugprone-copy-constructor-init, + bugprone-dangling-handle, + bugprone-incorrect-roundings, + bugprone-infinite-loop, + bugprone-redundant-branch-condition, + bugprone-suspicious-include, + bugprone-use-after-move, + modernize-redundant-void-arg, + modernize-use-emplace, + modernize-use-nullptr, + modernize-use-override, + performance-for-range-copy, + performance-move-const-arg, + performance-unnecessary-copy-initialization, + -clang-analyzer-optin.performance.Padding + +# Leave empty to warn only, not fail the build. +WarningsAsErrors: "" + +# Only surface warnings for project headers, not dependencies. +HeaderFilterRegex: "include/stochtree/.*" + +# Respect .clang-format for any fixes clang-tidy applies. +FormatStyle: file From 0e0051383adcb8fdafd539e47d45f2cbde595bbe Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:14:40 -0400 Subject: [PATCH 04/64] Auto-formatted C++ headers to the style standardized in `.clang-format` --- include/stochtree/category_tracker.h | 52 +-- include/stochtree/common.h | 213 ++++++------ include/stochtree/container.h | 90 ++--- include/stochtree/cutpoint_candidates.h | 78 ++--- include/stochtree/data.h | 216 ++++++------ include/stochtree/discrete_sampler.h | 12 +- include/stochtree/distributions.h | 46 +-- include/stochtree/ensemble.h | 166 ++++----- include/stochtree/export.h | 14 +- include/stochtree/gamma_sampler.h | 6 +- include/stochtree/ig_sampler.h | 12 +- include/stochtree/io.h | 445 ++++++++++++------------ include/stochtree/leaf_model.h | 109 +++--- include/stochtree/log.h | 28 +- include/stochtree/mainpage.h | 100 +++--- include/stochtree/meta.h | 42 +-- include/stochtree/normal_sampler.h | 14 +- include/stochtree/openmp_utils.h | 94 ++--- include/stochtree/ordinal_sampler.h | 28 +- include/stochtree/partition_tracker.h | 128 +++---- include/stochtree/prior.h | 49 +-- 
include/stochtree/random.h | 51 ++- include/stochtree/random_effects.h | 81 +++-- include/stochtree/tree.h | 235 ++++++------- include/stochtree/tree_sampler.h | 310 ++++++++--------- include/stochtree/variance_model.h | 14 +- 26 files changed, 1342 insertions(+), 1291 deletions(-) diff --git a/include/stochtree/category_tracker.h b/include/stochtree/category_tracker.h index 2ce44635..ba12ee3d 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -1,24 +1,24 @@ /*! * Copyright (c) 2024 stochtree authors. - * + * * General-purpose data structures used for keeping track of categories in a training dataset. - * - * SampleCategoryMapper is a simplified version of SampleNodeMapper, which is not tree-specific - * as it tracks categories loaded into a training dataset, and we do not expect to modify it during + * + * SampleCategoryMapper is a simplified version of SampleNodeMapper, which is not tree-specific + * as it tracks categories loaded into a training dataset, and we do not expect to modify it during * training. - * + * * SampleCategoryMapper is used in two places: * 1. Group random effects: mapping observations to group IDs for the purpose of computing random effects * 2. Heteroskedasticity based on fixed categories (as opposed to partitions as in HBART by Pratola et al 2018) - * - One example of this would be binary treatment causal inference with separate outcome variances + * - One example of this would be binary treatment causal inference with separate outcome variances * for the treated and control groups (as in Krantsevich et al 2023) - * - * CategorySampleTracker is a simplified version of FeatureUnsortedPartition, which as above does + * + * CategorySampleTracker is a simplified version of FeatureUnsortedPartition, which as above does * not vary based on tree / partition and is not expected to change during training. 
- * - * SampleNodeMapper is inspired by the design of the DataPartition class in LightGBM, + * + * SampleNodeMapper is inspired by the design of the DataPartition class in LightGBM, * released under the MIT license with the following copyright: - * + * * Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ @@ -44,8 +44,8 @@ class SampleCategoryMapper { num_observations_ = group_indices.size(); observation_indices_ = group_indices; } - - SampleCategoryMapper(SampleCategoryMapper& other){ + + SampleCategoryMapper(SampleCategoryMapper& other) { num_observations_ = other.NumObservations(); observation_indices_.resize(num_observations_); for (int i = 0; i < num_observations_; i++) { @@ -62,8 +62,8 @@ class SampleCategoryMapper { CHECK_LT(sample_id, num_observations_); observation_indices_[sample_id] = sample_id; } - - inline int NumObservations() {return num_observations_;} + + inline int NumObservations() { return num_observations_; } private: std::vector observation_indices_; @@ -80,16 +80,16 @@ class CategorySampleTracker { indices_ = std::vector(n); std::iota(indices_.begin(), indices_.end(), 0); - auto comp_op = [&](size_t const &l, size_t const &r) { return std::less{}(group_indices[l], group_indices[r]); }; + auto comp_op = [&](size_t const& l, size_t const& r) { return std::less{}(group_indices[l], group_indices[r]); }; std::stable_sort(indices_.begin(), indices_.end(), comp_op); category_count_ = 0; int observation_count = 0; for (int i = 0; i < n; i++) { bool start_cond = i == 0; - bool end_cond = i == n-1; + bool end_cond = i == n - 1; bool new_group_cond{false}; - if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]]; + if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i - 1]]; if (start_cond || new_group_cond) { category_id_map_.insert({group_indices[indices_[i]], category_count_}); 
unique_category_ids_.push_back(group_indices[indices_[i]]); @@ -103,7 +103,7 @@ class CategorySampleTracker { observation_count = 1; category_count_++; } else if (end_cond) { - category_length_.push_back(observation_count+1); + category_length_.push_back(observation_count + 1); } else { observation_count++; } @@ -118,7 +118,7 @@ class CategorySampleTracker { } /*! \brief First index of data points contained in node_id */ - inline data_size_t CategoryBegin(int category_id) {return category_begin_[category_id_map_[category_id]];} + inline data_size_t CategoryBegin(int category_id) { return category_begin_[category_id_map_[category_id]]; } /*! \brief One past the last index of data points contained in node_id */ inline data_size_t CategoryEnd(int category_id) { @@ -132,7 +132,7 @@ class CategorySampleTracker { } /*! \brief Number of total categories stored */ - inline data_size_t NumCategories() {return category_count_;} + inline data_size_t NumCategories() { return category_count_; } /*! \brief Data indices */ std::vector indices_; @@ -142,16 +142,16 @@ class CategorySampleTracker { int32_t id = category_id_map_[category_id]; return node_index_vector_[id]; } - + /*! \brief Data indices for a given node */ std::vector& NodeIndicesInternalIndex(int internal_category_id) { return node_index_vector_[internal_category_id]; } /*! 
\brief Returns label index map */ - std::map& GetLabelMap() {return category_id_map_;} + std::map& GetLabelMap() { return category_id_map_; } - std::vector& GetUniqueGroupIds() {return unique_category_ids_;} + std::vector& GetUniqueGroupIds() { return unique_category_ids_; } private: // Vectors tracking indices in each node @@ -163,6 +163,6 @@ class CategorySampleTracker { int32_t category_count_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_CATEGORY_TRACKER_H_ +#endif // STOCHTREE_CATEGORY_TRACKER_H_ diff --git a/include/stochtree/common.h b/include/stochtree/common.h index cd57eea2..5598926a 100644 --- a/include/stochtree/common.h +++ b/include/stochtree/common.h @@ -53,9 +53,9 @@ namespace StochTree { namespace Common { /*! -* Imbues the stream with the C locale. -*/ -static void C_stringstream(std::stringstream &ss) { + * Imbues the stream with the C locale. + */ +static void C_stringstream(std::stringstream& ss) { ss.imbue(std::locale::classic()); } @@ -190,7 +190,7 @@ inline static std::vector Split(const char* c_str, const char* deli return ret; } -template +template inline static const char* Atoi(const char* p, T* out) { int sign; T value; @@ -214,16 +214,16 @@ inline static const char* Atoi(const char* p, T* out) { return p; } -template +template inline static double Pow(T base, int power) { if (power < 0) { return 1.0 / Pow(base, -power); } else if (power == 0) { return 1; } else if (power % 2 == 0) { - return Pow(base*base, power / 2); + return Pow(base * base, power / 2); } else if (power % 3 == 0) { - return Pow(base*base*base, power / 3); + return Pow(base * base * base, power / 3); } else { return base * Pow(base, power - 1); } @@ -285,18 +285,29 @@ inline static const char* Atof(const char* p, double* out) { } if (expon > 308) expon = 308; // Calculate scaling factor. 
- while (expon >= 50) { scale *= 1E50; expon -= 50; } - while (expon >= 8) { scale *= 1E8; expon -= 8; } - while (expon > 0) { scale *= 10.0; expon -= 1; } + while (expon >= 50) { + scale *= 1E50; + expon -= 50; + } + while (expon >= 8) { + scale *= 1E8; + expon -= 8; + } + while (expon > 0) { + scale *= 10.0; + expon -= 1; + } } // Return signed and scaled floating point result. *out = sign * (frac ? (value / scale) : (value * scale)); } else { size_t cnt = 0; + // clang-format off while (*(p + cnt) != '\0' && *(p + cnt) != ' ' && *(p + cnt) != '\t' && *(p + cnt) != ',' && *(p + cnt) != '\n' && *(p + cnt) != '\r' && *(p + cnt) != ':') { + // clang-format on ++cnt; } if (cnt > 0) { @@ -331,7 +342,7 @@ inline static const char* AtofPrecise(const char* p, double* out) { // Rare path: Not in RFC 7159 format. Possible "inf", "nan", etc. Fallback to standard library: char* end2; - errno = 0; // This is Required before calling strtod. + errno = 0; // This is Required before calling strtod. *out = std::strtod(p, &end2); // strtod is locale aware. 
if (end2 == p) { Log::Fatal("no conversion to double for: %s", p); @@ -372,7 +383,7 @@ inline static const char* SkipReturn(const char* p) { return p; } -template +template inline static std::vector ArrayCast(const std::vector& arr) { std::vector ret(arr.size()); for (size_t i = 0; i < arr.size(); ++i) { @@ -381,7 +392,7 @@ inline static std::vector ArrayCast(const std::vector& arr) { return ret; } -template +template struct __StringToTHelper { T operator()(const std::string& str) const { T ret = 0; @@ -390,14 +401,14 @@ struct __StringToTHelper { } }; -template +template struct __StringToTHelper { T operator()(const std::string& str) const { return static_cast(std::stod(str)); } }; -template +template inline static std::vector StringToArray(const std::string& str, char delimiter) { std::vector strs = Split(str.c_str(), delimiter); std::vector ret; @@ -409,7 +420,7 @@ inline static std::vector StringToArray(const std::string& str, char delimite return ret; } -template +template inline static std::vector> StringToArrayofArrays( const std::string& str, char left_bracket, char right_bracket, char delimiter) { std::vector strs = SplitBrackets(str.c_str(), left_bracket, right_bracket); @@ -420,7 +431,7 @@ inline static std::vector> StringToArrayofArrays( return ret; } -template +template inline static std::vector StringToArray(const std::string& str, int n) { if (n == 0) { return std::vector(); @@ -436,16 +447,16 @@ inline static std::vector StringToArray(const std::string& str, int n) { return ret; } -template +template struct __StringToTHelperFast { - const char* operator()(const char*p, T* out) const { + const char* operator()(const char* p, T* out) const { return Atoi(p, out); } }; -template +template struct __StringToTHelperFast { - const char* operator()(const char*p, T* out) const { + const char* operator()(const char* p, T* out) const { double tmp = 0.0f; auto ret = Atof(p, &tmp); *out = static_cast(tmp); @@ -453,7 +464,7 @@ struct __StringToTHelperFast { } }; 
-template +template inline static std::vector StringToArrayFast(const std::string& str, int n) { if (n == 0) { return std::vector(); @@ -467,7 +478,7 @@ inline static std::vector StringToArrayFast(const std::string& str, int n) { return ret; } -template +template inline static std::string Join(const std::vector& strs, const char* delimiter, const bool force_C_locale = false) { if (strs.empty()) { return std::string(""); @@ -485,7 +496,7 @@ inline static std::string Join(const std::vector& strs, const char* delimiter return str_buf.str(); } -template<> +template <> inline std::string Join(const std::vector& strs, const char* delimiter, const bool force_C_locale) { if (strs.empty()) { return std::string(""); @@ -503,7 +514,7 @@ inline std::string Join(const std::vector& strs, const char* del return str_buf.str(); } -template +template inline static std::string Join(const std::vector& strs, size_t start, size_t end, const char* delimiter, const bool force_C_locale = false) { if (end - start <= 0) { return std::string(""); @@ -539,7 +550,7 @@ inline static int64_t Pow2RoundUp(int64_t x) { * \param p_rec The input/output vector of the values. 
*/ inline static void Softmax(std::vector* p_rec) { - std::vector &rec = *p_rec; + std::vector& rec = *p_rec; double wmax = rec[0]; for (size_t i = 1; i < rec.size(); ++i) { wmax = std::max(rec[i], wmax); @@ -569,16 +580,16 @@ inline static void Softmax(const double* input, double* output, int len) { } } -template +template std::vector ConstPtrInVectorWrapper(const std::vector>& input) { std::vector ret; - for (auto t = input.begin(); t !=input.end(); ++t) { + for (auto t = input.begin(); t != input.end(); ++t) { ret.push_back(t->get()); } return ret; } -template +template inline static void SortForPair(std::vector* keys, std::vector* values, size_t start, bool is_reverse = false) { std::vector> arr; auto& ref_key = *keys; @@ -644,14 +655,14 @@ inline static float AvoidInf(float x) { } } -template inline -static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) { +template +inline static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) { return (0); } // Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not template -inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) { +inline static void CheckElementsIntervalClosed(const T* y, T ymin, T ymax, int ny, const char* callername) { auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) { std::ostringstream os; os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]"; @@ -682,7 +693,7 @@ inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int n // One-pass scan over array w with nw elements: find min, max and sum of elements; // this is useful for checking weight requirements. 
template -inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) { +inline static void ObtainMinMaxSum(const T1* w, int nw, T1* mi, T1* ma, T2* su) { T1 minw; T1 maxw; T1 sumw; @@ -730,7 +741,7 @@ inline static std::vector EmptyBitset(int n) { return std::vector(size); } -template +template inline static void InsertBitset(std::vector* vec, const T val) { auto& ref_v = *vec; int i1 = val / 32; @@ -741,7 +752,7 @@ inline static void InsertBitset(std::vector* vec, const T val) { ref_v[i1] |= (1 << i2); } -template +template inline static std::vector ConstructBitset(const T* vals, int n) { std::vector ret; for (int i = 0; i < n; ++i) { @@ -755,7 +766,7 @@ inline static std::vector ConstructBitset(const T* vals, int n) { return ret; } -template +template inline static bool FindInBitset(const uint32_t* bits, int n, T pos) { int i1 = pos / 32; if (i1 >= n) { @@ -817,7 +828,7 @@ inline bool CheckAllowedJSON(const std::string& s) { || char_code == 93 // ] || char_code == 123 // { || char_code == 125 // } - ) { + ) { return false; } } @@ -979,19 +990,18 @@ class FunctionTimer { extern Common::Timer global_timer; - /*! -* Provides locale-independent alternatives to Common's methods. -* Essential to make models robust to locale settings. -*/ + * Provides locale-independent alternatives to Common's methods. + * Essential to make models robust to locale settings. 
+ */ namespace CommonC { -template +template inline static std::string Join(const std::vector& strs, const char* delimiter) { return StochTree::Common::Join(strs, delimiter, true); } -template +template inline static std::string Join(const std::vector& strs, size_t start, size_t end, const char* delimiter) { return StochTree::Common::Join(strs, start, end, delimiter, true); } @@ -1000,22 +1010,22 @@ inline static const char* Atof(const char* p, double* out) { return StochTree::Common::Atof(p, out); } -template +template struct __StringToTHelperFast { - const char* operator()(const char*p, T* out) const { + const char* operator()(const char* p, T* out) const { return StochTree::Common::Atoi(p, out); } }; /*! -* \warning Beware that ``Common::Atof`` in ``__StringToTHelperFast``, -* has **less** floating point precision than ``__StringToTHelper``. -* Both versions are kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. -* Check ``StringToArrayFast`` and ``StringToArray`` for more details on this. -*/ -template + * \warning Beware that ``Common::Atof`` in ``__StringToTHelperFast``, + * has **less** floating point precision than ``__StringToTHelper``. + * Both versions are kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. + * Check ``StringToArrayFast`` and ``StringToArray`` for more details on this. + */ +template struct __StringToTHelperFast { - const char* operator()(const char*p, T* out) const { + const char* operator()(const char* p, T* out) const { double tmp = 0.0f; auto ret = Atof(p, &tmp); *out = static_cast(tmp); @@ -1023,7 +1033,7 @@ struct __StringToTHelperFast { } }; -template +template struct __StringToTHelper { T operator()(const std::string& str) const { T ret = 0; @@ -1033,35 +1043,34 @@ struct __StringToTHelper { }; /*! -* \warning Beware that ``Common::Atof`` in ``__StringToTHelperFast``, -* has **less** floating point precision than ``__StringToTHelper``. 
-* Both versions are kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. -* Check ``StringToArrayFast`` and ``StringToArray`` for more details on this. -* \note It is possible that ``fast_double_parser::parse_number`` is faster than ``Common::Atof``. -*/ -template + * \warning Beware that ``Common::Atof`` in ``__StringToTHelperFast``, + * has **less** floating point precision than ``__StringToTHelper``. + * Both versions are kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. + * Check ``StringToArrayFast`` and ``StringToArray`` for more details on this. + * \note It is possible that ``fast_double_parser::parse_number`` is faster than ``Common::Atof``. + */ +template struct __StringToTHelper { T operator()(const std::string& str) const { double tmp; const char* end = Common::AtofPrecise(str.c_str(), &tmp); if (end == str.c_str()) { - Log::Fatal("Failed to parse double: %s", str.c_str()); + Log::Fatal("Failed to parse double: %s", str.c_str()); } return static_cast(tmp); } }; - /*! -* \warning Beware that due to internal use of ``Common::Atof`` in ``__StringToTHelperFast``, -* this method has less precision for floating point numbers than ``StringToArray``, -* which calls ``__StringToTHelper``. -* As such, ``StringToArrayFast`` and ``StringToArray`` are not equivalent! -* Both versions were kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. -*/ -template + * \warning Beware that due to internal use of ``Common::Atof`` in ``__StringToTHelperFast``, + * this method has less precision for floating point numbers than ``StringToArray``, + * which calls ``__StringToTHelper``. + * As such, ``StringToArrayFast`` and ``StringToArray`` are not equivalent! + * Both versions were kept to maintain bit-for-bit the "legacy" LightGBM behaviour in terms of precision. 
+ */ +template inline static std::vector StringToArrayFast(const std::string& str, int n) { if (n == 0) { return std::vector(); @@ -1076,11 +1085,11 @@ inline static std::vector StringToArrayFast(const std::string& str, int n) { } /*! -* \warning Do not replace calls to this method by ``StringToArrayFast``. -* This method is more precise for floating point numbers. -* Check ``StringToArrayFast`` for more details. -*/ -template + * \warning Do not replace calls to this method by ``StringToArrayFast``. + * This method is more precise for floating point numbers. + * Check ``StringToArrayFast`` for more details. + */ +template inline static std::vector StringToArray(const std::string& str, int n) { if (n == 0) { return std::vector(); @@ -1097,11 +1106,11 @@ inline static std::vector StringToArray(const std::string& str, int n) { } /*! -* \warning Do not replace calls to this method by ``StringToArrayFast``. -* This method is more precise for floating point numbers. -* Check ``StringToArrayFast`` for more details. -*/ -template + * \warning Do not replace calls to this method by ``StringToArrayFast``. + * This method is more precise for floating point numbers. + * Check ``StringToArrayFast`` for more details. + */ +template inline static std::vector StringToArray(const std::string& str, char delimiter) { std::vector strs = StochTree::Common::Split(str.c_str(), delimiter); std::vector ret; @@ -1114,37 +1123,37 @@ inline static std::vector StringToArray(const std::string& str, char delimite } /*! -* Safely formats a value onto a buffer according to a format string and null-terminates it. -* -* \note It checks that the full value was written or forcefully aborts. -* This safety check serves to prevent incorrect internal API usage. -* Correct usage will never incur in this problem: -* - The received buffer size shall be sufficient at all times for the input format string and value. 
-*/ + * Safely formats a value onto a buffer according to a format string and null-terminates it. + * + * \note It checks that the full value was written or forcefully aborts. + * This safety check serves to prevent incorrect internal API usage. + * Correct usage will never incur in this problem: + * - The received buffer size shall be sufficient at all times for the input format string and value. + */ template inline static void format_to_buf(char* buffer, const size_t buf_len, const char* format, const T value) { - auto result = fmt::format_to_n(buffer, buf_len, format, value); - if (result.size >= buf_len) { - Log::Fatal("Numerical conversion failed. Buffer is too small."); - } - buffer[result.size] = '\0'; + auto result = fmt::format_to_n(buffer, buf_len, format, value); + if (result.size >= buf_len) { + Log::Fatal("Numerical conversion failed. Buffer is too small."); + } + buffer[result.size] = '\0'; } -template +template struct __TToStringHelper { void operator()(T value, char* buffer, size_t buf_len) const { format_to_buf(buffer, buf_len, "{}", value); } }; -template +template struct __TToStringHelper { void operator()(T value, char* buffer, size_t buf_len) const { format_to_buf(buffer, buf_len, "{:g}", value); } }; -template +template struct __TToStringHelper { void operator()(T value, char* buffer, size_t buf_len) const { format_to_buf(buffer, buf_len, "{:.17g}", value); @@ -1152,14 +1161,14 @@ struct __TToStringHelper { }; /*! -* Converts an array to a string with with values separated by the space character. -* This method replaces Common's ``ArrayToString`` and ``ArrayToStringFast`` functionality -* and is locale-independent. -* -* \note If ``high_precision_output`` is set to true, -* floating point values are output with more digits of precision. -*/ -template + * Converts an array to a string with with values separated by the space character. 
+ * This method replaces Common's ``ArrayToString`` and ``ArrayToStringFast`` functionality + * and is locale-independent. + * + * \note If ``high_precision_output`` is set to true, + * floating point values are output with more digits of precision. + */ +template inline static std::string ArrayToString(const std::vector& arr, size_t n) { if (arr.empty() || n == 0) { return std::string(""); diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 4b75ef2f..8840bbbc 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -1,6 +1,6 @@ /*! * Copyright (c) 2024 stochtree authors. All rights reserved. - * + * * Simple container-like interfaces for samples of common models. */ #ifndef STOCHTREE_CONTAINER_H_ @@ -23,8 +23,8 @@ namespace StochTree { class ForestContainer { public: /*! - * \brief Construct a new ForestContainer object. - * + * \brief Construct a new ForestContainer object. + * * \param num_trees Number of trees in each forest. * \param output_dimension Dimension of the leaf node parameter in each tree of each forest. * \param is_leaf_constant Whether or not the leaves of each tree are treated as "constant." If true, then predicting from an ensemble is simply a matter or determining which leaf node an observation falls into. If false, prediction will multiply a leaf node's parameter(s) for a given observation by a basis vector. @@ -33,7 +33,7 @@ class ForestContainer { ForestContainer(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false); /*! * \brief Construct a new ForestContainer object. - * + * * \param num_samples Initial size of a container of forest samples. * \param num_trees Number of trees in each forest. * \param output_dimension Dimension of the leaf node parameter in each tree of each forest. @@ -44,7 +44,7 @@ class ForestContainer { ~ForestContainer() {} /*! 
* \brief Combine two forests into a single forest by merging their trees - * + * * \param inbound_forest_index Index of the forest that will be appended to * \param outbound_forest_index Index of the forest that will be appended */ @@ -53,7 +53,7 @@ class ForestContainer { } /*! * \brief Add a constant value to every leaf of every tree of a specified forest - * + * * \param forest_index Index of forest whose leaves will be modified * \param constant_value Value to add to every leaf of every tree of the forest at `forest_index` */ @@ -62,7 +62,7 @@ class ForestContainer { } /*! * \brief Multiply every leaf of every tree of a specified forest by a constant value - * + * * \param forest_index Index of forest whose leaves will be modified * \param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index` */ @@ -71,62 +71,62 @@ class ForestContainer { } /*! * \brief Remove a forest from a container of forest samples and delete the corresponding object, freeing its memory. - * + * * \param sample_num Index of forest to be deleted. */ void DeleteSample(int sample_num); /*! * \brief Add a new forest to the container by copying `forest`. - * + * * \param forest Forest to be copied and added to the container of retained forest samples. */ void AddSample(TreeEnsemble& forest); /*! * \brief Initialize a "root" forest of univariate trees as the first element of the container, setting all root node values in every tree to `leaf_value`. - * + * * \param leaf_value Value to assign to the root node of every tree. */ void InitializeRoot(double leaf_value); /*! * \brief Initialize a "root" forest of multivariate trees as the first element of the container, setting all root node values in every tree to `leaf_vector`. - * + * * \param leaf_value Vector of values to assign to the root node of every tree. */ void InitializeRoot(std::vector& leaf_vector); /*! * \brief Pre-allocate space for `num_samples` additional forests in the container. 
- * + * * \param num_samples Number of (default-constructed) forests to allocated space for in the container. */ void AddSamples(int num_samples); /*! * \brief Copy the forest stored at `previous_sample_id` to the forest stored at `new_sample_id`. - * + * * \param new_sample_id Index of the new forest to be copied from an earlier sample. * \param previous_sample_id Index of the previous forest to copy to `new_sample_id`. */ void CopyFromPreviousSample(int new_sample_id, int previous_sample_id); /*! - * \brief Predict from every forest in the container on every observation in the provided dataset. - * The resulting vector is "column-major", where every forest in a container defines the columns of a - * prediction matrix and every observation in the provided dataset defines the rows. The (`i`,`j`) element - * of this prediction matrix can be read from the `j * num_rows + i` element of the returned `std::vector`, + * \brief Predict from every forest in the container on every observation in the provided dataset. + * The resulting vector is "column-major", where every forest in a container defines the columns of a + * prediction matrix and every observation in the provided dataset defines the rows. The (`i`,`j`) element + * of this prediction matrix can be read from the `j * num_rows + i` element of the returned `std::vector`, * where `num_rows` is equal to the number of observations in `dataset` (i.e. `dataset.NumObservations()`). - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights. * \return std::vector Vector of predictions for every forest in the container and every observation in `dataset`. */ std::vector Predict(ForestDataset& dataset); /*! - * \brief Predict from every forest in the container on every observation in the provided dataset. 
- * The resulting vector stores a possibly three-dimensional array, where the dimensions are arranged as follows - * + * \brief Predict from every forest in the container on every observation in the provided dataset. + * The resulting vector stores a possibly three-dimensional array, where the dimensions are arranged as follows + * * 1. Dimension of the leaf node's raw values (1 for GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, and LogLinearVarianceLeafModel, >1 for GaussianMultivariateRegressionLeafModel) * 2. Observations in the provided dataset. * 3. Forest samples in the container. - * + * * If the leaf nodes have univariate values, then the "first dimension" is 1 and the resulting array has the exact same layout as in \ref Predict. - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights. * \return std::vector Vector of predictions for every forest in the container and every observation in `dataset`. 
*/ @@ -137,17 +137,17 @@ class ForestContainer { void PredictRawInPlace(ForestDataset& dataset, std::vector& output); void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector& output); void PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector& output); - void PredictLeafIndicesInplace(Eigen::Map>& covariates, - Eigen::Map>& output, + void PredictLeafIndicesInplace(Eigen::Map>& covariates, + Eigen::Map>& output, std::vector& forest_indices, int num_trees, data_size_t n); - inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();} - inline int32_t NumSamples() {return num_samples_;} - inline int32_t NumTrees() {return num_trees_;} - inline int32_t NumTrees(int ensemble_num) {return forests_[ensemble_num]->NumTrees();} - inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();} - inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);} - inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();} + inline TreeEnsemble* GetEnsemble(int i) { return forests_[i].get(); } + inline int32_t NumSamples() { return num_samples_; } + inline int32_t NumTrees() { return num_trees_; } + inline int32_t NumTrees(int ensemble_num) { return forests_[ensemble_num]->NumTrees(); } + inline int32_t NumLeaves(int ensemble_num) { return forests_[ensemble_num]->NumLeaves(); } + inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) { return forests_[ensemble_num]->TreeMaxDepth(tree_num); } + inline double EnsembleAverageMaxDepth(int ensemble_num) { return forests_[ensemble_num]->AverageMaxDepth(); } inline double AverageMaxDepth() { double numerator = 0.; double denominator = 0.; @@ -159,23 +159,23 @@ class ForestContainer { } return numerator / denominator; } - inline int32_t OutputDimension() {return output_dimension_;} - inline int32_t OutputDimension(int 
ensemble_num) {return forests_[ensemble_num]->OutputDimension();} - inline bool IsLeafConstant() {return is_leaf_constant_;} - inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();} - inline bool IsExponentiated() {return is_exponentiated_;} - inline bool IsExponentiated(int ensemble_num) {return forests_[ensemble_num]->IsExponentiated();} - inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();} - inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);} - inline void SetLeafVector(int ensemble_num, std::vector& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);} - inline void IncrementSampleCount() {num_samples_++;} + inline int32_t OutputDimension() { return output_dimension_; } + inline int32_t OutputDimension(int ensemble_num) { return forests_[ensemble_num]->OutputDimension(); } + inline bool IsLeafConstant() { return is_leaf_constant_; } + inline bool IsLeafConstant(int ensemble_num) { return forests_[ensemble_num]->IsLeafConstant(); } + inline bool IsExponentiated() { return is_exponentiated_; } + inline bool IsExponentiated(int ensemble_num) { return forests_[ensemble_num]->IsExponentiated(); } + inline bool AllRoots(int ensemble_num) { return forests_[ensemble_num]->AllRoots(); } + inline void SetLeafValue(int ensemble_num, double leaf_value) { forests_[ensemble_num]->SetLeafValue(leaf_value); } + inline void SetLeafVector(int ensemble_num, std::vector& leaf_vector) { forests_[ensemble_num]->SetLeafVector(leaf_vector); } + inline void IncrementSampleCount() { num_samples_++; } void SaveToJsonFile(std::string filename) { nlohmann::json model_json = this->to_json(); std::ofstream output_file(filename); output_file << model_json << std::endl; } - + void LoadFromJsonFile(std::string filename) { std::ifstream f(filename); nlohmann::json file_tree_json = nlohmann::json::parse(f); @@ -219,6 +219,6 @@ class ForestContainer { 
bool is_leaf_constant_; bool initialized_{false}; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_CONTAINER_H_ +#endif // STOCHTREE_CONTAINER_H_ diff --git a/include/stochtree/cutpoint_candidates.h b/include/stochtree/cutpoint_candidates.h index 76f1df4c..a4503c19 100644 --- a/include/stochtree/cutpoint_candidates.h +++ b/include/stochtree/cutpoint_candidates.h @@ -1,39 +1,39 @@ /*! * Copyright (c) 2024 stochtree authors. - * + * * Data structures for enumerating potential cutpoint candidates. - * + * * This is used in the XBART family of algorithms, which samples split rules - * based on the log marginal likelihood of every potential cutpoint. For numeric - * variables with large sample sizes, it is often unnecessary to consider every + * based on the log marginal likelihood of every potential cutpoint. For numeric + * variables with large sample sizes, it is often unnecessary to consider every * unique value, so we allow for an adaptive "grid" of potential cutpoint values. - * - * Algorithms for enumerating cutpoints take Dataset and SortedNodeSampleTracker objects - * as inputs, so that each feature is "pre-sorted" according to its value within a - * given node. The size of the adaptive cutpoint grid is set by the + * + * Algorithms for enumerating cutpoints take Dataset and SortedNodeSampleTracker objects + * as inputs, so that each feature is "pre-sorted" according to its value within a + * given node. The size of the adaptive cutpoint grid is set by the * cutpoint_grid_size configuration parameter. 
- * + * * Numeric Features * ---------------- - * - * When a node has fewer available observations than cutpoint_grid_size, - * full enumeration of unique available cutpoints is done via the + * + * When a node has fewer available observations than cutpoint_grid_size, + * full enumeration of unique available cutpoints is done via the * `EnumerateNumericCutpointsDeduplication` function - * - * When a node has more available observations than cutpoint_grid_size, - * potential cutpoints are "thinned out" by considering every k-th observation, + * + * When a node has more available observations than cutpoint_grid_size, + * potential cutpoints are "thinned out" by considering every k-th observation, * where k is implied by the number of observations and the target cutpoint_grid_size. - * + * * Ordered Categorical Features * ---------------------------- - * - * In this case, the grid is every unique value of the ordered categorical + * + * In this case, the grid is every unique value of the ordered categorical * feature in ascending order. - * + * * Unordered Categorical Features * ------------------------------ - * - * In this case, the grid is every unique value of the unordered categorical feature, + * + * In this case, the grid is every unique value of the unordered categorical feature, * arranged in an outcome-dependent order, as described in Fisher (1958) */ #ifndef STOCHTREE_CUTPOINT_CANDIDATES_H_ @@ -45,7 +45,7 @@ namespace StochTree { /*! 
\brief Computing and tracking cutpoints available for a given feature at a given node - * Store cutpoint bins in 0-indexed fashion, so that if a given node has + * Store cutpoint bins in 0-indexed fashion, so that if a given node has */ class FeatureCutpointGrid { public: @@ -66,19 +66,19 @@ class FeatureCutpointGrid { void CalculateStridesUnorderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Number of potential cutpoints enumerated */ - int32_t NumCutpoints() {return node_stride_begin_.size();} + int32_t NumCutpoints() { return node_stride_begin_.size(); } /*! \brief Beginning index of bin i */ - int32_t BinStartIndex(int i) {return node_stride_begin_.at(i);} + int32_t BinStartIndex(int i) { return node_stride_begin_.at(i); } /*! \brief Size of bin i */ - int32_t BinLength(int i) {return node_stride_length_.at(i);} + int32_t BinLength(int i) { return node_stride_length_.at(i); } /*! \brief Beginning index of bin i */ - int32_t BinEndIndex(int i) {return node_stride_begin_.at(i) + node_stride_length_.at(i);} + int32_t BinEndIndex(int i) { return node_stride_begin_.at(i) + node_stride_length_.at(i); } /*! \brief Value of the upper-bound (cutpoint) implied by bin i */ - double CutpointValue(int i) {return cutpoint_values_.at(i);} + double CutpointValue(int i) { return cutpoint_values_.at(i); } /*! \brief Vector of cutpoint values up to and including bin i * Helper function for converting categorical split "value" (as outlined in Fisher 1958) to a set of categories @@ -135,22 +135,22 @@ class CutpointGridContainer { } /*! \brief Max size of cutpoint grid */ - int32_t CutpointGridSize() {return cutpoint_grid_size_;} + int32_t CutpointGridSize() { return cutpoint_grid_size_; } /*! 
\brief Number of potential cutpoints enumerated */ - int32_t NumCutpoints(int feature_index) {return feature_cutpoint_grid_[feature_index]->NumCutpoints();} + int32_t NumCutpoints(int feature_index) { return feature_cutpoint_grid_[feature_index]->NumCutpoints(); } /*! \brief Beginning index of bin i */ - int32_t BinStartIndex(int i, int feature_index) {return feature_cutpoint_grid_[feature_index]->BinStartIndex(i);} + int32_t BinStartIndex(int i, int feature_index) { return feature_cutpoint_grid_[feature_index]->BinStartIndex(i); } /*! \brief Size of bin i */ - int32_t BinLength(int i, int feature_index) {return feature_cutpoint_grid_[feature_index]->BinLength(i);} + int32_t BinLength(int i, int feature_index) { return feature_cutpoint_grid_[feature_index]->BinLength(i); } /*! \brief Beginning index of bin i */ - int32_t BinEndIndex(int i, int feature_index) {return feature_cutpoint_grid_[feature_index]->BinEndIndex(i);} + int32_t BinEndIndex(int i, int feature_index) { return feature_cutpoint_grid_[feature_index]->BinEndIndex(i); } /*! \brief Value of the upper-bound (cutpoint) implied by bin i */ - double CutpointValue(int i, int feature_index) {return feature_cutpoint_grid_[feature_index]->CutpointValue(i);} + double CutpointValue(int i, int feature_index) { return feature_cutpoint_grid_[feature_index]->CutpointValue(i); } /*! 
\brief Vector of cutpoint values up to and including bin i * Helper function for converting categorical split "value" (as outlined in Fisher 1958) to a set of categories @@ -159,7 +159,7 @@ class CutpointGridContainer { return feature_cutpoint_grid_[feature_index]->CutpointVector(i); } - FeatureCutpointGrid* GetFeatureCutpointGrid(int feature_num) {return feature_cutpoint_grid_[feature_num].get(); } + FeatureCutpointGrid* GetFeatureCutpointGrid(int feature_num) { return feature_cutpoint_grid_[feature_num].get(); } private: std::vector> feature_cutpoint_grid_; @@ -184,7 +184,7 @@ class NodeCutpointTracker { void CalculateStridesCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Number of potential cutpoints enumerated */ - int32_t NumCutpoints() {return node_stride_begin_.size();} + int32_t NumCutpoints() { return node_stride_begin_.size(); } /*! \brief Whether a cutpoint grid has been enumerated for a given node */ bool NodeCutpointEvaluated(int32_t node_id) { @@ -192,18 +192,18 @@ class NodeCutpointTracker { } /*! \brief Node id of the node that has been most recently evaluated */ - int32_t CurrentNodeEvaluated() {return current_node_;} + int32_t CurrentNodeEvaluated() { return current_node_; } /*! \brief Vectors of node stride starting points and stride lengths */ std::vector node_stride_begin_; std::vector node_stride_length_; - + private: int32_t cutpoint_grid_size_; std::vector nodes_enumerated_; int32_t current_node_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_CUTPOINT_CANDIDATES_H_ +#endif // STOCHTREE_CUTPOINT_CANDIDATES_H_ diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 393203b1..e81dc17b 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -21,10 +21,10 @@ namespace StochTree { */ /*! 
- * \brief Extract multiple features from the raw data loaded from a file into an `Eigen::MatrixXd`. + * \brief Extract multiple features from the raw data loaded from a file into an `Eigen::MatrixXd`. * Lightly modified from LightGBM's datasetloader interface to support `stochtree`'s use cases. * \internal - * + * * \param text_data Vector of data reads as string from a file. * \param parser Pointer to a parser object (i.e. `CSVParser`). * \param column_indices Integer labels of columns to be extracted from `text_data` into `data`. @@ -42,7 +42,7 @@ static inline void ExtractMultipleFeaturesFromMemory(std::vector* t // unpack the vector of textlines read from file into a vector of (int, double) tuples oneline_features.clear(); parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features); - + // free processed line: ref_text_data[i].clear(); @@ -50,9 +50,8 @@ static inline void ExtractMultipleFeaturesFromMemory(std::vector* t int feature_counter = 0; for (auto& inner_data : oneline_features) { int feature_idx = inner_data.first; - column_matched = (std::find(column_indices.begin(), column_indices.end(), feature_idx) - != column_indices.end()); - if (column_matched){ + column_matched = (std::find(column_indices.begin(), column_indices.end(), feature_idx) != column_indices.end()); + if (column_matched) { data(i, feature_counter) = inner_data.second; feature_counter += 1; } @@ -63,10 +62,10 @@ static inline void ExtractMultipleFeaturesFromMemory(std::vector* t } /*! -* \brief Extract a single feature from the raw data loaded from a file into an `Eigen::VectorXd`. + * \brief Extract a single feature from the raw data loaded from a file into an `Eigen::VectorXd`. * Lightly modified from LightGBM's datasetloader interface to support `stochtree`'s use cases. * \internal - * + * * \param text_data Vector of data reads as string from a file. * \param parser Pointer to a parser object (i.e. `CSVParser`). 
* \param column_index Integer labels of columns to be extracted from `text_data` into `data`. @@ -82,14 +81,14 @@ static inline void ExtractSingleFeatureFromMemory(std::vector* text // unpack the vector of textlines read from file into a vector of (int, double) tuples oneline_features.clear(); parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features); - + // free processed line: ref_text_data[i].clear(); // unload the data from oneline_features vector into the dataset variables containers for (auto& inner_data : oneline_features) { int feature_idx = inner_data.first; - if (column_index == feature_idx){ + if (column_index == feature_idx) { data(i) = inner_data.second; } } @@ -134,7 +133,7 @@ class ColumnMatrix { ColumnMatrix() {} /*! * \brief Construct a new `ColumnMatrix` object from in-memory data buffer. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a matrix. * \param num_row Number of rows in the matrix. * \param num_col Number of columns / covariates in the matrix. @@ -143,7 +142,7 @@ class ColumnMatrix { ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); /*! * \brief Construct a new ColumnMatrix object from CSV file - * + * * \param filename Name of the file (including any necessary path prefixes). * \param column_index_string Comma-delimited string listing columns to extract into covariates matrix. * \param header Whether or not the file contains a header of column names / non-data. @@ -153,11 +152,11 @@ class ColumnMatrix { ~ColumnMatrix() {} /*! * \brief Returns the value stored at (`row`, `col`) in the object's internal `Eigen::MatrixXd`. - * + * * \param row Row number to query in the matrix * \param col Column number to query in the matrix */ - double GetElement(data_size_t row_num, int32_t col_num) {return data_(row_num, col_num);} + double GetElement(data_size_t row_num, int32_t col_num) { return data_(row_num, col_num); } /*! 
* \brief Update an observation in the object's internal `Eigen::MatrixXd` to a new value. * @@ -165,10 +164,10 @@ class ColumnMatrix { * \param col Column number to be overwritten. * \param value New value to write in (`row`, `col`) in the object's internal `Eigen::MatrixXd`. */ - void SetElement(data_size_t row_num, int32_t col_num, double value) {data_(row_num, col_num) = value;} + void SetElement(data_size_t row_num, int32_t col_num, double value) { data_(row_num, col_num) = value; } /*! * \brief Update the data in a `ColumnMatrix` object from an in-memory data buffer. This will erase the existing matrix. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a matrix. * \param num_row Number of rows in the matrix. * \param num_col Number of columns / covariates in the matrix. @@ -176,18 +175,19 @@ class ColumnMatrix { */ void LoadData(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); /*! \brief Number of rows in the object's internal `Eigen::MatrixXd`. */ - inline data_size_t NumRows() {return data_.rows();} + inline data_size_t NumRows() { return data_.rows(); } /*! \brief Number of columns in the object's internal `Eigen::MatrixXd`. */ - inline int NumCols() {return data_.cols();} + inline int NumCols() { return data_.cols(); } /*! \brief Return a reference to the object's internal `Eigen::MatrixXd`, for interfaces that require a raw matrix. */ - inline Eigen::MatrixXd& GetData() {return data_;} + inline Eigen::MatrixXd& GetData() { return data_; } + private: Eigen::MatrixXd data_; }; /*! - * \brief Internal wrapper around `Eigen::VectorXd` interface for univariate floating point data. - * The (frequently updated) full / partial residual used in sampling forests is stored internally + * \brief Internal wrapper around `Eigen::VectorXd` interface for univariate floating point data. 
+ * The (frequently updated) full / partial residual used in sampling forests is stored internally * as a `ColumnVector` by the sampling functions (see \ref sampling_group). */ class ColumnVector { @@ -195,14 +195,14 @@ class ColumnVector { ColumnVector() {} /*! * \brief Construct a new `ColumnVector` object from in-memory data buffer. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. * \param num_row Number of rows / elements in the vector. */ ColumnVector(double* data_ptr, data_size_t num_row); /*! * \brief Construct a new ColumnMatrix object from CSV file - * + * * \param filename Name of the file (including any necessary path prefixes). * \param column_index Integer index of the column in `filename` to be unpacked as a vector. * \param header Whether or not the file contains a header of column names / non-data. @@ -212,60 +212,61 @@ class ColumnVector { ~ColumnVector() {} /*! * \brief Returns the value stored at position `row` in the object's internal `Eigen::VectorXd`. - * + * * \param row Row number to query in the vector */ - double GetElement(data_size_t row) {return data_(row);} + double GetElement(data_size_t row) { return data_(row); } /*! * \brief Returns the value stored at position `row` in the object's internal `Eigen::VectorXd`. - * + * * \param row Row number to query in the vector * \param value New value to write to element `row` of the object's internal `Eigen::VectorXd`. */ - void SetElement(data_size_t row, double value) {data_(row) = value;} + void SetElement(data_size_t row, double value) { data_(row) = value; } /*! * \brief Update the data in a `ColumnVector` object from an in-memory data buffer. This will erase the existing vector. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. * \param num_row Number of rows / elements in the vector. */ void LoadData(double* data_ptr, data_size_t num_row); /*! 
- * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by adding each value obtained + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by adding each value obtained * in `data_ptr` to the existing values in the object's internal `Eigen::VectorXd`. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. * \param num_row Number of rows / elements in the vector. */ void AddToData(double* data_ptr, data_size_t num_row); /*! - * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by subtracting each value obtained + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by subtracting each value obtained * in `data_ptr` from the existing values in the object's internal `Eigen::VectorXd`. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. * \param num_row Number of rows / elements in the vector. */ void SubtractFromData(double* data_ptr, data_size_t num_row); /*! - * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by substituting each value obtained + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by substituting each value obtained * in `data_ptr` for the existing values in the object's internal `Eigen::VectorXd`. - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. * \param num_row Number of rows / elements in the vector. */ void OverwriteData(double* data_ptr, data_size_t num_row); /*! \brief Number of rows in the object's internal `Eigen::VectorXd`. */ - inline data_size_t NumRows() {return data_.size();} + inline data_size_t NumRows() { return data_.size(); } /*! \brief Return a reference to the object's internal `Eigen::VectorXd`, for interfaces that require a raw vector. 
*/ - inline Eigen::VectorXd& GetData() {return data_;} + inline Eigen::VectorXd& GetData() { return data_; } + private: Eigen::VectorXd data_; void UpdateData(double* data_ptr, data_size_t num_row, std::function op); }; -/*! - * \brief API for loading and accessing data used to sample tree ensembles - * The covariates / bases / weights used in sampling forests are stored internally +/*! + * \brief API for loading and accessing data used to sample tree ensembles + * The covariates / bases / weights used in sampling forests are stored internally * as a `ForestDataset` by the sampling functions (see \ref sampling_group). */ class ForestDataset { @@ -275,7 +276,7 @@ class ForestDataset { ~ForestDataset() {} /*! * \brief Copy / load covariates from raw memory buffer (often pointer to data in a R matrix or numpy array) - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a covariate matrix * \param num_row Number of rows in the covariate matrix * \param num_col Number of columns / covariates in the covariate matrix @@ -289,7 +290,7 @@ class ForestDataset { } /*! * \brief Copy / load basis matrix from raw memory buffer (often pointer to data in a R matrix or numpy array) - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix * \param num_row Number of rows in the basis matrix * \param num_col Number of columns in the basis matrix @@ -302,7 +303,7 @@ class ForestDataset { } /*! * \brief Copy / load variance weights from raw memory buffer (often pointer to data in a R vector or numpy array) - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing weights * \param num_row Number of rows in the weight vector */ @@ -312,7 +313,7 @@ class ForestDataset { } /*! 
* \brief Copy / load covariates from CSV file - * + * * \param filename Name of the file (including any necessary path prefixes) * \param column_index_string Comma-delimited string listing columns to extract into covariates matrix */ @@ -324,7 +325,7 @@ class ForestDataset { } /*! * \brief Copy / load basis matrix from CSV file - * + * * \param filename Name of the file (including any necessary path prefixes) * \param column_index_string Comma-delimited string listing columns to extract into covariates matrix */ @@ -335,7 +336,7 @@ class ForestDataset { } /*! * \brief Copy / load variance / case weights from CSV file - * + * * \param filename Name of the file (including any necessary path prefixes) * \param column_index Integer index of column containing weights */ @@ -344,58 +345,58 @@ class ForestDataset { has_var_weights_ = true; } /*! \brief Whether or not a `ForestDataset` has (yet) loaded covariate data */ - inline bool HasCovariates() {return has_covariates_;} + inline bool HasCovariates() { return has_covariates_; } /*! \brief Whether or not a `ForestDataset` has (yet) loaded basis data */ - inline bool HasBasis() {return has_basis_;} + inline bool HasBasis() { return has_basis_; } /*! \brief Whether or not a `ForestDataset` has (yet) loaded variance weights */ - inline bool HasVarWeights() {return has_var_weights_;} + inline bool HasVarWeights() { return has_var_weights_; } /*! \brief Number of observations (rows) in the dataset */ - inline data_size_t NumObservations() {return num_observations_;} + inline data_size_t NumObservations() { return num_observations_; } /*! \brief Number of covariate columns in the dataset */ - inline int NumCovariates() {return num_covariates_;} + inline int NumCovariates() { return num_covariates_; } /*! \brief Number of bases in the dataset. This is 0 if the dataset has not been provided a basis matrix. */ - inline int NumBasis() {return num_basis_;} + inline int NumBasis() { return num_basis_; } /*! 
* \brief Returns a dataset's covariate value stored at (`row`, `col`) - * + * * \param row Row number to query in the covariate matrix * \param col Column number to query in the covariate matrix */ - inline double CovariateValue(data_size_t row, int col) {return covariates_.GetElement(row, col);} + inline double CovariateValue(data_size_t row, int col) { return covariates_.GetElement(row, col); } /*! * \brief Returns a dataset's basis value stored at (`row`, `col`) - * + * * \param row Row number to query in the basis matrix * \param col Column number to query in the basis matrix */ - inline double BasisValue(data_size_t row, int col) {return basis_.GetElement(row, col);} + inline double BasisValue(data_size_t row, int col) { return basis_.GetElement(row, col); } /*! * \brief Returns a dataset's variance weight stored at element `row` - * + * * \param row Index to query in the weight vector */ - inline double VarWeightValue(data_size_t row) {return var_weights_.GetElement(row);} + inline double VarWeightValue(data_size_t row) { return var_weights_.GetElement(row); } /*! * \brief Return a reference to the raw `Eigen::MatrixXd` storing the covariate data - * + * * \return Reference to internal Eigen::MatrixXd */ - inline Eigen::MatrixXd& GetCovariates() {return covariates_.GetData();} + inline Eigen::MatrixXd& GetCovariates() { return covariates_.GetData(); } /*! * \brief Return a reference to the raw `Eigen::MatrixXd` storing the basis data - * + * * \return Reference to internal Eigen::MatrixXd */ - inline Eigen::MatrixXd& GetBasis() {return basis_.GetData();} + inline Eigen::MatrixXd& GetBasis() { return basis_.GetData(); } /*! * \brief Return a reference to the raw `Eigen::VectorXd` storing the variance weights - * + * * \return Reference to internal Eigen::VectorXd */ - inline Eigen::VectorXd& GetVarWeights() {return var_weights_.GetData();} + inline Eigen::VectorXd& GetVarWeights() { return var_weights_.GetData(); } /*! 
* \brief Update the data in the internal basis matrix to new values stored in a raw double array - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix * \param num_row Number of rows in the basis matrix * \param num_col Number of columns in the basis matrix @@ -408,7 +409,7 @@ class ForestDataset { double temp_value; for (data_size_t i = 0; i < num_row; ++i) { for (int j = 0; j < num_col; ++j) { - if (is_row_major){ + if (is_row_major) { // Numpy 2-d arrays are stored in "row major" order temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); } else { @@ -431,8 +432,10 @@ class ForestDataset { // Copy data from R / Python process memory to Eigen vector double temp_value; for (data_size_t i = 0; i < num_row; ++i) { - if (exponentiate) temp_value = std::exp(static_cast(*(data_ptr + i))); - else temp_value = static_cast(*(data_ptr + i)); + if (exponentiate) + temp_value = std::exp(static_cast(*(data_ptr + i))); + else + temp_value = static_cast(*(data_ptr + i)); var_weights_.SetElement(i, temp_value); } } @@ -466,11 +469,13 @@ class ForestDataset { */ void SetVarWeightValue(data_size_t row_id, double new_value, bool exponentiate = true) { CHECK(has_var_weights_); - if (exponentiate) var_weights_.SetElement(row_id, std::exp(new_value)); - else var_weights_.SetElement(row_id, new_value); + if (exponentiate) + var_weights_.SetElement(row_id, std::exp(new_value)); + else + var_weights_.SetElement(row_id, new_value); } - /*! - * \brief Auxiliary data management methods + /*! + * \brief Auxiliary data management methods * Methods to initialize, get, and set auxiliary data for BART models with more structure than the ``classic`` conjugate-Gaussian leaf BART model */ void AddAuxiliaryDimension(int dim_size) { @@ -506,9 +511,9 @@ class ForestDataset { bool has_basis_{false}; bool has_var_weights_{false}; - /*! 
- * \brief Vector of vectors to track (potentially jagged) auxiliary data for complex BART models - */ + /*! + * \brief Vector of vectors to track (potentially jagged) auxiliary data for complex BART models + */ std::vector> auxiliary_data_; int num_auxiliary_dims_{0}; bool has_auxiliary_data_{false}; @@ -522,30 +527,30 @@ class RandomEffectsDataset { ~RandomEffectsDataset() {} /*! * \brief Copy / load basis matrix from raw memory buffer (often pointer to data in a R matrix or numpy array) - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix * \param num_row Number of rows in the basis matrix * \param num_col Number of columns in the basis matrix * \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion */ - void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { + void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); num_basis_ = num_col; has_basis_ = true; } /*! * \brief Copy / load variance weights from raw memory buffer (often pointer to data in a R vector or numpy array) - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing weights * \param num_row Number of rows in the weight vector - */ + */ void AddVarianceWeights(double* data_ptr, data_size_t num_row) { var_weights_ = ColumnVector(data_ptr, num_row); has_var_weights_ = true; } /*! 
* \brief Update the data in the internal basis matrix to new values stored in a raw double array - * + * * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix * \param num_row Number of rows in the basis matrix * \param num_col Number of columns in the basis matrix @@ -558,7 +563,7 @@ class RandomEffectsDataset { double temp_value; for (data_size_t i = 0; i < num_row; ++i) { for (int j = 0; j < num_col; ++j) { - if (is_row_major){ + if (is_row_major) { // Numpy 2-d arrays are stored in "row major" order temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); } else { @@ -581,8 +586,10 @@ class RandomEffectsDataset { // Copy data from R / Python process memory to Eigen vector double temp_value; for (data_size_t i = 0; i < num_row; ++i) { - if (exponentiate) temp_value = std::exp(static_cast(*(data_ptr + i))); - else temp_value = static_cast(*(data_ptr + i)); + if (exponentiate) + temp_value = std::exp(static_cast(*(data_ptr + i))); + else + temp_value = static_cast(*(data_ptr + i)); var_weights_.SetElement(i, temp_value); } } @@ -603,8 +610,8 @@ class RandomEffectsDataset { } /*! * \brief Copy / load group indices for random effects - * - * \param group_labels Vector of integers with as many elements as `num_row` in the basis matrix, + * + * \param group_labels Vector of integers with as many elements as `num_row` in the basis matrix, * where each element corresponds to the group label for a given observation. */ void AddGroupLabels(std::vector& group_labels) { @@ -612,52 +619,53 @@ class RandomEffectsDataset { has_group_labels_ = true; } /*! \brief Number of observations (rows) in the dataset */ - inline data_size_t NumObservations() {return basis_.NumRows();} + inline data_size_t NumObservations() { return basis_.NumRows(); } /*! \brief Number of columns of the basis vector in the dataset */ - inline int NumBases() {return basis_.NumCols();} + inline int NumBases() { return basis_.NumCols(); } /*! 
\brief Whether or not a `RandomEffectsDataset` has (yet) loaded basis data */ - inline bool HasBasis() {return has_basis_;} + inline bool HasBasis() { return has_basis_; } /*! \brief Whether or not a `RandomEffectsDataset` has (yet) loaded variance weights */ - inline bool HasVarWeights() {return has_var_weights_;} + inline bool HasVarWeights() { return has_var_weights_; } /*! \brief Whether or not a `RandomEffectsDataset` has (yet) loaded group labels */ - inline bool HasGroupLabels() {return has_group_labels_;} + inline bool HasGroupLabels() { return has_group_labels_; } /*! * \brief Returns a dataset's basis value stored at (`row`, `col`) - * + * * \param row Row number to query in the basis matrix * \param col Column number to query in the basis matrix */ - inline double BasisValue(data_size_t row, int col) {return basis_.GetElement(row, col);} + inline double BasisValue(data_size_t row, int col) { return basis_.GetElement(row, col); } /*! * \brief Returns a dataset's variance weight stored at element `row` - * + * * \param row Index to query in the weight vector */ - inline double VarWeightValue(data_size_t row) {return var_weights_.GetElement(row);} + inline double VarWeightValue(data_size_t row) { return var_weights_.GetElement(row); } /*! * \brief Returns a dataset's group label stored at element `row` - * + * * \param row Index to query in the group label vector */ - inline int32_t GroupId(data_size_t row) {return group_labels_[row];} + inline int32_t GroupId(data_size_t row) { return group_labels_[row]; } /*! * \brief Return a reference to the raw `Eigen::MatrixXd` storing the basis data - * + * * \return Reference to internal Eigen::MatrixXd */ - inline Eigen::MatrixXd& GetBasis() {return basis_.GetData();} + inline Eigen::MatrixXd& GetBasis() { return basis_.GetData(); } /*! 
* \brief Return a reference to the raw `Eigen::VectorXd` storing the variance weights - * + * * \return Reference to internal Eigen::VectorXd */ - inline Eigen::VectorXd& GetVarWeights() {return var_weights_.GetData();} + inline Eigen::VectorXd& GetVarWeights() { return var_weights_.GetData(); } /*! * \brief Return a reference to the raw `std::vector` storing the group labels - * + * * \return Reference to internal std::vector */ - inline std::vector& GetGroupLabels() {return group_labels_;} + inline std::vector& GetGroupLabels() { return group_labels_; } + private: ColumnMatrix basis_; ColumnVector var_weights_; @@ -668,8 +676,8 @@ class RandomEffectsDataset { bool has_group_labels_{false}; }; -/*! \} */ // end of data_group +/*! \} */ // end of data_group -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_DATA_H_ +#endif // STOCHTREE_DATA_H_ diff --git a/include/stochtree/discrete_sampler.h b/include/stochtree/discrete_sampler.h index a513d032..507875d5 100644 --- a/include/stochtree/discrete_sampler.h +++ b/include/stochtree/discrete_sampler.h @@ -12,7 +12,7 @@ namespace StochTree { -/*! \brief Sample without replacement according to a set of probability weights. +/*! \brief Sample without replacement according to a set of probability weights. 
* This template function is a C++ variant of numpy's implementation: * https://github.com/numpy/numpy/blob/031f44252d613f4524ad181e3eb2ae2791e22187/numpy/random/_generator.pyx#L925 */ @@ -23,19 +23,19 @@ void sample_without_replacement(container_type* output, prob_type* p, container_ std::vector indices(sample_size); std::vector unif_samples(sample_size); std::vector cdf(population_size); - + int fulfilled_sample_count = 0; int remaining_sample_count = sample_size - fulfilled_sample_count; while (fulfilled_sample_count < sample_size) { if (fulfilled_sample_count > 0) { for (int i = 0; i < fulfilled_sample_count; i++) p_copy[indices[i]] = 0.0; } - std::generate(unif_samples.begin(), unif_samples.begin() + remaining_sample_count, [&gen](){ + std::generate(unif_samples.begin(), unif_samples.begin() + remaining_sample_count, [&gen]() { return standard_uniform_draw_53bit(gen); }); std::partial_sum(p_copy.cbegin(), p_copy.cend(), cdf.begin()); for (int i = 0; i < cdf.size(); i++) { - cdf[i] = cdf[i] / cdf[cdf.size()-1]; + cdf[i] = cdf[i] / cdf[cdf.size() - 1]; } std::vector matches(remaining_sample_count); for (int i = 0; i < remaining_sample_count; i++) { @@ -60,6 +60,6 @@ void sample_without_replacement(container_type* output, prob_type* p, container_ } } -} +} // namespace StochTree -#endif // STOCHTREE_DISCRETE_SAMPLER_H_ +#endif // STOCHTREE_DISCRETE_SAMPLER_H_ diff --git a/include/stochtree/distributions.h b/include/stochtree/distributions.h index f946975a..f31b7b16 100644 --- a/include/stochtree/distributions.h +++ b/include/stochtree/distributions.h @@ -5,7 +5,7 @@ /*! * \brief A collection of random number generation utilities. * - * This file is vendored from a broader C++ / R distribution + * This file is vendored from a broader C++ / R distribution * library, where the distributions are subject to rigorous testing. * https://github.com/andrewherren/cpp11_r_rng */ @@ -68,9 +68,9 @@ class standard_normal { /*! 
* Stateless standard normal sampler implementing Marsaglia's polar method. * Without caching, this is half as fast as other methods for repeated normal sampling, - * but this might be acceptable in cases where a relatively small number of + * but this might be acceptable in cases where a relatively small number of * normal draws is desired. - * + * * Reference: https://en.wikipedia.org/wiki/Marsaglia_polar_method */ inline double sample_standard_normal(double mean, double sd, std::mt19937& gen) { @@ -130,10 +130,10 @@ class gamma_sampler { v = v * v * v; double u = standard_uniform_draw_53bit(gen); if (u < 1.0 - 0.0331 * (x * x) * (x * x)) { - return b * v * scale; + return b * v * scale; } if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) { - return b * v * scale; + return b * v * scale; } } } else { @@ -202,23 +202,23 @@ inline double sample_gamma(std::mt19937& gen, double shape, double scale) { while (true) { double x, v; do { - // Marsaglia's polar method for standard normal + // Marsaglia's polar method for standard normal double u1, u2, s; do { u1 = standard_uniform_draw_53bit(gen) * 2.0 - 1.0; u2 = standard_uniform_draw_53bit(gen) * 2.0 - 1.0; s = u1 * u1 + u2 * u2; } while (s >= 1.0 || s == 0.0); - x = u1 * std::sqrt(-2.0 * std::log(s) / s); + x = u1 * std::sqrt(-2.0 * std::log(s) / s); v = 1.0 + c * x; } while (v <= 0.0); v = v * v * v; double u = standard_uniform_draw_53bit(gen); if (u < 1.0 - 0.0331 * (x * x) * (x * x)) { - return b * v * scale; + return b * v * scale; } if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) { - return b * v * scale; + return b * v * scale; } } } else { @@ -228,13 +228,13 @@ inline double sample_gamma(std::mt19937& gen, double shape, double scale) { /*! * Walker-Vose alias method for sampling with replacement from a weighted discrete distribution. 
- * + * * Simplified from https://github.com/boostorg/random/blob/develop/include/boost/random/discrete_distribution.hpp * Other references: https://en.wikipedia.org/wiki/Alias_method */ class walker_vose { public: - template + template walker_vose(Iterator first, Iterator last) { n_ = std::distance(first, last); probability_.resize(n_); @@ -245,7 +245,7 @@ class walker_vose { for (auto it = first; it != last; ++it) { sum += *it; } - + // Build alias table using Walker's algorithm std::vector p(n_); std::vector below_average, above_average; @@ -258,33 +258,35 @@ class walker_vose { above_average.push_back(i); } } - + while (!below_average.empty() && !above_average.empty()) { - int j = below_average.back(); below_average.pop_back(); - int i = above_average.back(); above_average.pop_back(); - + int j = below_average.back(); + below_average.pop_back(); + int i = above_average.back(); + above_average.pop_back(); + probability_[j] = p[j]; alias_[j] = i; p[i] = (p[i] + p[j]) - 1.0; - + if (p[i] < 1.0) { below_average.push_back(i); } else { above_average.push_back(i); } } - + while (!above_average.empty()) { probability_[above_average.back()] = 1.0; above_average.pop_back(); } - + while (!below_average.empty()) { probability_[below_average.back()] = 1.0; below_average.pop_back(); } } - + int operator()(std::mt19937& gen) { double u = standard_uniform_draw_53bit(gen); int i = static_cast(u * n_); @@ -323,6 +325,6 @@ inline int sample_discrete_stateless(std::mt19937& gen, std::vector& wei return weights.size() - 1; } -} +} // namespace StochTree -#endif // STOCHTREE_DISTRIBUTIONS_H \ No newline at end of file +#endif // STOCHTREE_DISTRIBUTIONS_H \ No newline at end of file diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 4f6ddf42..449b7ea6 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2024 stochtree authors. All rights reserved. * Licensed under the MIT License. 
See LICENSE file in the project root for license information. - * - * Inspired by the design of the Learner, GBTreeModel, and GBTree classes in xgboost, + * + * Inspired by the design of the Learner, GBTreeModel, and GBTree classes in xgboost, * released under the Apache license with the following copyright: - * + * * Copyright 2015-2023 by XGBoost Contributors */ #ifndef STOCHTREE_ENSEMBLE_H_ @@ -32,7 +32,7 @@ class TreeEnsemble { public: /*! * \brief Initialize a new TreeEnsemble - * + * * \param num_trees Number of trees in a forest * \param output_dimension Dimension of the leaf node parameter * \param is_leaf_constant Whether or not the leaves of each tree are treated as "constant." If true, then predicting from an ensemble is simply a matter or determining which leaf node an observation falls into. If false, prediction will multiply a leaf node's parameter(s) for a given observation by a basis vector. @@ -51,10 +51,10 @@ class TreeEnsemble { is_leaf_constant_ = is_leaf_constant; is_exponentiated_ = is_exponentiated; } - + /*! * \brief Initialize an ensemble based on the state of an existing ensemble - * + * * \param ensemble `TreeEnsemble` used to initialize the current ensemble */ TreeEnsemble(TreeEnsemble& ensemble) { @@ -74,12 +74,12 @@ class TreeEnsemble { this->CloneFromExistingTree(j, tree); } } - + ~TreeEnsemble() {} /*! * \brief Combine two forests into a single forest by merging their trees - * + * * \param ensemble Reference to another `TreeEnsemble` that will be merged into the current ensemble */ void MergeForest(TreeEnsemble& ensemble) { @@ -103,7 +103,7 @@ class TreeEnsemble { /*! * \brief Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. - * + * * \param constant_value Value that will be added to every leaf of every tree */ void AddValueToLeaves(double constant_value) { @@ -115,7 +115,7 @@ class TreeEnsemble { /*! 
* \brief Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. - * + * * \param constant_multiple Value that will be multiplied by every leaf of every tree */ void MultiplyLeavesByValue(double constant_multiple) { @@ -127,9 +127,9 @@ class TreeEnsemble { /*! * \brief Return a pointer to a tree in the forest - * + * * \param i Index (0-based) of a tree to be queried - * \return Tree* + * \return Tree* */ inline Tree* GetTree(int i) { return trees_[i].get(); @@ -147,7 +147,7 @@ class TreeEnsemble { /*! * \brief Reset a single tree in an ensemble * \todo Consider refactoring this and `ResetInitTree` - * + * * \param i Index (0-based) of the tree to be reset */ inline void ResetTree(int i) { @@ -157,7 +157,7 @@ class TreeEnsemble { /*! * \brief Reset a single tree in an ensemble * \todo Consider refactoring this and `ResetTree` - * + * * \param i Index (0-based) of the tree to be reset */ inline void ResetInitTree(int i) { @@ -167,7 +167,7 @@ class TreeEnsemble { /*! * \brief Clone a single tree in an ensemble from an existing tree, overwriting current tree - * + * * \param i Index of the tree to be overwritten * \param tree Pointer to tree used to clone tree `i` */ @@ -177,7 +177,7 @@ class TreeEnsemble { /*! 
* \brief Reset an ensemble to clone another ensemble - * + * * \param ensemble Reference to an existing `TreeEnsemble` */ inline void ReconstituteFromForest(TreeEnsemble& ensemble) { @@ -214,12 +214,12 @@ class TreeEnsemble { PredictRawInplace(dataset, output, 0); return output; } - - inline void PredictInplace(ForestDataset& dataset, std::vector &output, data_size_t offset = 0) { + + inline void PredictInplace(ForestDataset& dataset, std::vector& output, data_size_t offset = 0) { PredictInplace(dataset, output, 0, trees_.size(), offset); } - inline void PredictInplace(ForestDataset& dataset, std::vector &output, + inline void PredictInplace(ForestDataset& dataset, std::vector& output, int tree_begin, int tree_end, data_size_t offset = 0) { if (is_leaf_constant_) { PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset); @@ -229,11 +229,11 @@ class TreeEnsemble { } } - inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector &output, data_size_t offset = 0) { + inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector& output, data_size_t offset = 0) { PredictInplace(covariates, basis, output, 0, trees_.size(), offset); } - inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector &output, + inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector& output, int tree_begin, int tree_end, data_size_t offset = 0) { double pred; CHECK_EQ(covariates.rows(), basis.rows()); @@ -247,22 +247,24 @@ class TreeEnsemble { for (data_size_t i = 0; i < n; i++) { pred = 0.0; for (size_t j = tree_begin; j < tree_end; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; std::int32_t nidx = EvaluateTree(tree, covariates, i); for (int32_t k = 0; k < output_dimension_; k++) { pred += tree.LeafValue(nidx, k) * basis(i, k); } } - if (is_exponentiated_) output[i + offset] = std::exp(pred); - else output[i + offset] = pred; + 
if (is_exponentiated_) + output[i + offset] = std::exp(pred); + else + output[i + offset] = pred; } } - inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector &output, data_size_t offset = 0) { + inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector& output, data_size_t offset = 0) { PredictInplace(covariates, output, 0, trees_.size(), offset); } - inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector &output, int tree_begin, int tree_end, data_size_t offset = 0) { + inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector& output, int tree_begin, int tree_end, data_size_t offset = 0) { double pred; data_size_t n = covariates.rows(); data_size_t total_output_size = n; @@ -272,21 +274,23 @@ class TreeEnsemble { for (data_size_t i = 0; i < n; i++) { pred = 0.0; for (size_t j = tree_begin; j < tree_end; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; std::int32_t nidx = EvaluateTree(tree, covariates, i); pred += tree.LeafValue(nidx, 0); } - if (is_exponentiated_) output[i + offset] = std::exp(pred); - else output[i + offset] = pred; + if (is_exponentiated_) + output[i + offset] = std::exp(pred); + else + output[i + offset] = pred; } } - inline void PredictRawInplace(ForestDataset& dataset, std::vector &output, data_size_t offset = 0) { + inline void PredictRawInplace(ForestDataset& dataset, std::vector& output, data_size_t offset = 0) { PredictRawInplace(dataset, output, 0, trees_.size(), offset); } - inline void PredictRawInplace(ForestDataset& dataset, std::vector &output, - int tree_begin, int tree_end, data_size_t offset = 0) { + inline void PredictRawInplace(ForestDataset& dataset, std::vector& output, + int tree_begin, int tree_end, data_size_t offset = 0) { double pred; Eigen::MatrixXd covariates = dataset.GetCovariates(); CHECK_EQ(output_dimension_, trees_[0]->OutputDimension()); @@ -299,11 +303,11 @@ class TreeEnsemble { for (int32_t k = 0; k < output_dimension_; k++) { pred = 0.0; for (size_t j 
= tree_begin; j < tree_end; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; int32_t nidx = EvaluateTree(tree, covariates, i); pred += tree.LeafValue(nidx, k); } - output[i*output_dimension_ + k + offset] = pred; + output[i * output_dimension_ + k + offset] = pred; } } } @@ -380,30 +384,30 @@ class TreeEnsemble { } /*! - * \brief Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the + * \brief Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the * number of leaves in each tree. This is used in conjunction with `PredictLeafIndicesInplace`, * which returns an observation-specific leaf index for every observation-tree pair. */ int GetMaxLeafIndex() { int max_leaf = 0; for (int j = 0; j < num_trees_; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; max_leaf += tree.NumLeaves(); } return max_leaf; } /*! - * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each - * observation in a ForestDataset. Internally, trees are stored as essentially - * vectors of node information, and the leaves_ vector gives us node IDs for every - * leaf in the tree. Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. - * We compute this at the tree-level and coordinate this computation at the + * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each + * observation in a ForestDataset. Internally, trees are stored as essentially + * vectors of node information, and the leaves_ vector gives us node IDs for every + * leaf in the tree. Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. 
+ * We compute this at the tree-level and coordinate this computation at the * ensemble level. * - * Note: this assumes the creation of a vector of column indices of size + * Note: this assumes the creation of a vector of column indices of size * `dataset.NumObservations()` x `ensemble.NumTrees()` * \param ForestDataset Dataset with which to predict leaf indices from the tree * \param output Vector of length num_trees*n which stores the leaf node prediction @@ -415,16 +419,16 @@ class TreeEnsemble { } /*! - * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each - * observation in a ForestDataset. Internally, trees are stored as essentially - * vectors of node information, and the leaves_ vector gives us node IDs for every - * leaf in the tree. Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. - * We compute this at the tree-level and coordinate this computation at the + * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each + * observation in a ForestDataset. Internally, trees are stored as essentially + * vectors of node information, and the leaves_ vector gives us node IDs for every + * leaf in the tree. Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. + * We compute this at the tree-level and coordinate this computation at the * ensemble level. 
* - * Note: this assumes the creation of a vector of column indices of size + * Note: this assumes the creation of a vector of column indices of size * `dataset.NumObservations()` x `ensemble.NumTrees()` * \param covariates Matrix of covariates * \param output Vector of length num_trees*n which stores the leaf node prediction @@ -432,11 +436,11 @@ class TreeEnsemble { * \param n Size of dataset */ void PredictLeafIndicesInplace(Eigen::Map>& covariates, std::vector& output, int num_trees, data_size_t n) { - CHECK_GE(output.size(), num_trees*n); + CHECK_GE(output.size(), num_trees * n); int offset = 0; int max_leaf = 0; for (int j = 0; j < num_trees; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; int num_leaves = tree.NumLeaves(); tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf); offset += n; @@ -445,13 +449,13 @@ class TreeEnsemble { } /*! - * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each - * observation in a ForestDataset. Internally, trees are stored as essentially - * vectors of node information, and the leaves_ vector gives us node IDs for every - * leaf in the tree. Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. - * We compute this at the tree-level and coordinate this computation at the + * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each + * observation in a ForestDataset. Internally, trees are stored as essentially + * vectors of node information, and the leaves_ vector gives us node IDs for every + * leaf in the tree. Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. 
+ * We compute this at the tree-level and coordinate this computation at the * ensemble level. * * Note: this assumes the creation of a matrix of column indices with `num_trees*n` rows @@ -462,14 +466,14 @@ class TreeEnsemble { * \param num_trees Number of trees in an ensemble * \param n Size of dataset */ - void PredictLeafIndicesInplace(Eigen::Map>& covariates, - Eigen::Map>& output, + void PredictLeafIndicesInplace(Eigen::Map>& covariates, + Eigen::Map>& output, int column_ind, int num_trees, data_size_t n) { - CHECK_GE(output.size(), num_trees*n); + CHECK_GE(output.size(), num_trees * n); int offset = 0; int max_leaf = 0; for (int j = 0; j < num_trees; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; int num_leaves = tree.NumLeaves(); tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf); offset += n; @@ -478,16 +482,16 @@ class TreeEnsemble { } /*! - * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each - * observation in a ForestDataset. Internally, trees are stored as essentially - * vectors of node information, and the leaves_ vector gives us node IDs for every - * leaf in the tree. Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. - * We compute this at the tree-level and coordinate this computation at the + * \brief Obtain a 0-based leaf index for every tree in an ensemble and for each + * observation in a ForestDataset. Internally, trees are stored as essentially + * vectors of node information, and the leaves_ vector gives us node IDs for every + * leaf in the tree. Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from 0 to `leaves_.size()-1`. 
+ * We compute this at the tree-level and coordinate this computation at the * ensemble level. * - * Note: this assumes the creation of a vector of column indices of size + * Note: this assumes the creation of a vector of column indices of size * `dataset.NumObservations()` x `ensemble.NumTrees()` * \param ForestDataset Dataset with which to predict leaf indices from the tree * \param output Vector of length num_trees*n which stores the leaf node prediction @@ -495,11 +499,11 @@ class TreeEnsemble { * \param n Size of dataset */ void PredictLeafIndicesInplace(Eigen::MatrixXd& covariates, std::vector& output, int num_trees, data_size_t n) { - CHECK_GE(output.size(), num_trees*n); + CHECK_GE(output.size(), num_trees * n); int offset = 0; int max_leaf = 0; for (int j = 0; j < num_trees; j++) { - auto &tree = *trees_[j]; + auto& tree = *trees_[j]; int num_leaves = tree.NumLeaves(); tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf); offset += n; @@ -514,7 +518,7 @@ class TreeEnsemble { std::vector PredictLeafIndices(ForestDataset* dataset) { int num_trees = num_trees_; data_size_t n = dataset->NumObservations(); - std::vector output(n*num_trees); + std::vector output(n * num_trees); PredictLeafIndicesInplace(dataset, output, num_trees, n); return output; } @@ -532,10 +536,10 @@ class TreeEnsemble { tree_label = "tree_" + std::to_string(i); result_obj.emplace(tree_label, trees_[i]->to_json()); } - + return result_obj; } - + /*! \brief Load from JSON */ void from_json(const json& ensemble_json) { this->num_trees_ = ensemble_json.at("num_trees"); @@ -561,8 +565,8 @@ class TreeEnsemble { bool is_exponentiated_; }; -/*! \} */ // end of forest_group +/*! 
\} */ // end of forest_group -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_ENSEMBLE_H_ +#endif // STOCHTREE_ENSEMBLE_H_ diff --git a/include/stochtree/export.h b/include/stochtree/export.h index 4b651749..fb5b339b 100644 --- a/include/stochtree/export.h +++ b/include/stochtree/export.h @@ -1,11 +1,11 @@ /*! - * Export macros ensure that the C++ code can be used as a library cross-platform - * (declspec needed to load names from a DLL on windows) and can be wrapped in a + * Export macros ensure that the C++ code can be used as a library cross-platform + * (declspec needed to load names from a DLL on windows) and can be wrapped in a * C program. - * - * This code modifies (changing names of) the export macros in LightGBM, which carries + * + * This code modifies (changing names of) the export macros in LightGBM, which carries * the following copyright information: - * + * * Copyright (c) 2017 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ @@ -24,8 +24,8 @@ #define STOCHTREE_EXPORT __declspec(dllexport) #define STOCHTREE_C_EXPORT STOCHTREE_EXTERN_C __declspec(dllexport) #else -#define STOCHTREE_EXPORT __attribute__ ((visibility ("default"))) -#define STOCHTREE_C_EXPORT STOCHTREE_EXTERN_C __attribute__ ((visibility ("default"))) +#define STOCHTREE_EXPORT __attribute__((visibility("default"))) +#define STOCHTREE_C_EXPORT STOCHTREE_EXTERN_C __attribute__((visibility("default"))) #endif #endif /** STOCHTREE_EXPORT_H_ **/ diff --git a/include/stochtree/gamma_sampler.h b/include/stochtree/gamma_sampler.h index 1c524f4f..53d2f332 100644 --- a/include/stochtree/gamma_sampler.h +++ b/include/stochtree/gamma_sampler.h @@ -12,11 +12,11 @@ class GammaSampler { GammaSampler() {} ~GammaSampler() {} double Sample(double a, double b, std::mt19937& gen, bool rate_param = true) { - double scale = rate_param ? 1./b : b; + double scale = rate_param ? 1. 
/ b : b; return sample_gamma(gen, a, scale); } }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_GAMMA_SAMPLER_H_ \ No newline at end of file +#endif // STOCHTREE_GAMMA_SAMPLER_H_ \ No newline at end of file diff --git a/include/stochtree/ig_sampler.h b/include/stochtree/ig_sampler.h index 669b9e56..6da28fad 100644 --- a/include/stochtree/ig_sampler.h +++ b/include/stochtree/ig_sampler.h @@ -13,15 +13,15 @@ class InverseGammaSampler { ~InverseGammaSampler() {} double Sample(double a, double b, std::mt19937& gen, bool scale_param = true) { // C++ standard library provides a gamma distribution with scale - // parameter, but the correspondence between gamma and IG is that + // parameter, but the correspondence between gamma and IG is that // 1 / gamma(a,b) ~ IG(a,b) when b is a __rate__ parameter. - // Before sampling, we convert ig_scale to a gamma scale parameter by + // Before sampling, we convert ig_scale to a gamma scale parameter by // taking its multiplicative inverse. - double gamma_scale = scale_param ? 1./b : b; - return (1/sample_gamma(gen, a, gamma_scale)); + double gamma_scale = scale_param ? 1. / b : b; + return (1 / sample_gamma(gen, a, gamma_scale)); } }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_IG_SAMPLER_H_ \ No newline at end of file +#endif // STOCHTREE_IG_SAMPLER_H_ \ No newline at end of file diff --git a/include/stochtree/io.h b/include/stochtree/io.h index 55963946..eeb259e7 100644 --- a/include/stochtree/io.h +++ b/include/stochtree/io.h @@ -7,8 +7,8 @@ * parser.h * pipeline_reader.h * text_reader.h - * - * LightGBM is MIT licensed and released with the following copyright header + * + * LightGBM is MIT licensed and released with the following copyright header * (with different copyright years in different files): * * Copyright (c) 2016 Microsoft Corporation. All rights reserved. @@ -42,9 +42,9 @@ namespace StochTree { const size_t kGbs = size_t(1024) * 1024 * 1024; /*! 
-* \brief Contains some operation for an array, e.g. ArgMax, TopK. -*/ -template + * \brief Contains some operation for an array, e.g. ArgMax, TopK. + */ +template class ArrayArgs { public: inline static size_t ArgMax(const std::vector& array) { @@ -116,18 +116,35 @@ class ArrayArgs { std::vector& ref = *arr; VAL_T v = ref[end - 1]; for (;;) { - while (ref[++i] > v) {} - while (v > ref[--j]) { if (j == start) { break; } } - if (i >= j) { break; } + while (ref[++i] > v) { + } + while (v > ref[--j]) { + if (j == start) { + break; + } + } + if (i >= j) { + break; + } std::swap(ref[i], ref[j]); - if (ref[i] == v) { p++; std::swap(ref[p], ref[i]); } - if (v == ref[j]) { q--; std::swap(ref[j], ref[q]); } + if (ref[i] == v) { + p++; + std::swap(ref[p], ref[i]); + } + if (v == ref[j]) { + q--; + std::swap(ref[j], ref[q]); + } } std::swap(ref[i], ref[end - 1]); j = i - 1; i = i + 1; - for (int k = start; k <= p; k++, j--) { std::swap(ref[k], ref[j]); } - for (int k = end - 2; k >= q; k--, i++) { std::swap(ref[i], ref[k]); } + for (int k = start; k <= p; k++, j--) { + std::swap(ref[k], ref[j]); + } + for (int k = end - 2; k >= q; k--, i++) { + std::swap(ref[i], ref[k]); + } *l = j; *r = i; } @@ -193,24 +210,24 @@ class ArrayArgs { }; /*! - * \brief An interface for serializing binary data to a buffer - */ + * \brief An interface for serializing binary data to a buffer + */ struct BinaryWriter { /*! - * \brief Append data to this binary target - * \param data Buffer to write from - * \param bytes Number of bytes to write from buffer - * \return Number of bytes written - */ + * \brief Append data to this binary target + * \param data Buffer to write from + * \param bytes Number of bytes to write from buffer + * \return Number of bytes written + */ virtual size_t Write(const void* data, size_t bytes) = 0; /*! 
- * \brief Append data to this binary target aligned on a given byte size boundary - * \param data Buffer to write from - * \param bytes Number of bytes to write from buffer - * \param alignment The size of bytes to align to in whole increments - * \return Number of bytes written - */ + * \brief Append data to this binary target aligned on a given byte size boundary + * \param data Buffer to write from + * \param bytes Number of bytes to write from buffer + * \param alignment The size of bytes to align to in whole increments + * \return Number of bytes written + */ size_t AlignedWrite(const void* data, size_t bytes, size_t alignment = 8) { auto ret = Write(data, bytes); if (bytes % alignment != 0) { @@ -222,11 +239,11 @@ struct BinaryWriter { } /*! - * \brief The aligned size of a buffer length. - * \param bytes The number of bytes in a buffer - * \param alignment The size of bytes to align to in whole increments - * \return Number of aligned bytes - */ + * \brief The aligned size of a buffer length. + * \param bytes The number of bytes in a buffer + * \param alignment The size of bytes to align to in whole increments + * \return Number of aligned bytes + */ static size_t AlignedSize(size_t bytes, size_t alignment = 8) { if (bytes % alignment == 0) { return bytes; @@ -301,42 +318,42 @@ class Parser { Parser() {} /*! - * \brief Constructor for customized parser. The constructor accepts content not path because need to save/load the config along with model string - */ + * \brief Constructor for customized parser. The constructor accepts content not path because need to save/load the config along with model string + */ explicit Parser(std::string) {} /*! \brief virtual destructor */ virtual ~Parser() {} /*! 
- * \brief Parse one line with label - * \param str One line record, string format, should end with '\0' - * \param out_features Output columns, store in (column_idx, values) - */ + * \brief Parse one line with label + * \param str One line record, string format, should end with '\0' + * \param out_features Output columns, store in (column_idx, values) + */ virtual void ParseOneLine(const char* str, std::vector>* out_features) const = 0; virtual int NumFeatures() const = 0; /*! - * \brief Create an object of parser, will auto choose the format depend on file - * \param filename One Filename of data - * \param header whether input file contains header - * \param num_features Pass num_features of this data file if you know, <=0 means don't know - * \param precise_float_parser using precise floating point number parsing if true - * \return Object of parser - */ + * \brief Create an object of parser, will auto choose the format depend on file + * \param filename One Filename of data + * \param header whether input file contains header + * \param num_features Pass num_features of this data file if you know, <=0 means don't know + * \param precise_float_parser using precise floating point number parsing if true + * \return Object of parser + */ static Parser* CreateParser(const char* filename, bool header, int num_features, bool precise_float_parser); }; -class CSVParser: public Parser { +class CSVParser : public Parser { public: explicit CSVParser(int total_columns, AtofFunc atof) - :total_columns_(total_columns), atof_(atof) { + : total_columns_(total_columns), atof_(atof) { } inline void ParseOneLine(const char* str, - std::vector>* out_features) const override { + std::vector>* out_features) const override { int idx = 0; double val = 0.0f; int offset = 0; @@ -365,22 +382,22 @@ class CSVParser: public Parser { }; /*! 
-* \brief A pipeline file reader, use 2 threads, one read block from file, the other process the block -*/ + * \brief A pipeline file reader, use 2 threads, one read block from file, the other process the block + */ class PipelineReader { public: /*! - * \brief Read data from a file, use pipeline methods - * \param filename Filename of data - * \process_fun Process function - */ + * \brief Read data from a file, use pipeline methods + * \param filename Filename of data + * \process_fun Process function + */ static size_t Read(const char* filename, int skip_bytes, const std::function& process_fun) { auto reader = VirtualFileReader::Make(filename); if (!reader->Init()) { return 0; } size_t cnt = 0; - const size_t buffer_size = 16 * 1024 * 1024; + const size_t buffer_size = 16 * 1024 * 1024; // buffer used for the process_fun auto buffer_process = std::vector(buffer_size); // buffer used for the file reading @@ -397,9 +414,9 @@ class PipelineReader { while (read_cnt > 0) { // start read thread std::thread read_worker = std::thread( - [=, &last_read_cnt, &reader, &buffer_read] { - last_read_cnt = reader->Read(buffer_read.data(), buffer_size); - }); + [=, &last_read_cnt, &reader, &buffer_read] { + last_read_cnt = reader->Read(buffer_read.data(), buffer_size); + }); // start process cnt += process_fun(buffer_process.data(), read_cnt); // wait for read thread @@ -413,18 +430,17 @@ class PipelineReader { }; /*! -* \brief Read text data from file -*/ -template + * \brief Read text data from file + */ +template class TextReader { public: /*! 
- * \brief Constructor - * \param filename Filename of data - * \param is_skip_first_line True if need to skip header - */ - TextReader(const char* filename, bool is_skip_first_line, size_t progress_interval_bytes = SIZE_MAX): - filename_(filename), is_skip_first_line_(is_skip_first_line), read_progress_interval_bytes_(progress_interval_bytes) { + * \brief Constructor + * \param filename Filename of data + * \param is_skip_first_line True if need to skip header + */ + TextReader(const char* filename, bool is_skip_first_line, size_t progress_interval_bytes = SIZE_MAX) : filename_(filename), is_skip_first_line_(is_skip_first_line), read_progress_interval_bytes_(progress_interval_bytes) { if (is_skip_first_line_) { auto reader = VirtualFileReader::Make(filename); if (!reader->Init()) { @@ -454,33 +470,33 @@ class TextReader { } } /*! - * \brief Destructor - */ + * \brief Destructor + */ ~TextReader() { Clear(); } /*! - * \brief Clear cached data - */ + * \brief Clear cached data + */ inline void Clear() { lines_.clear(); lines_.shrink_to_fit(); } /*! - * \brief return first line of data - */ + * \brief return first line of data + */ inline std::string first_line() { return first_line_; } /*! - * \brief Get text data that read from file - * \return Text data, store in std::vector by line - */ + * \brief Get text data that read from file + * \return Text data, store in std::vector by line + */ inline std::vector& Lines() { return lines_; } /*! 
- * \brief Get joined text data that read from file - * \return Text data, store in std::string, joined all lines by delimiter - */ + * \brief Get joined text data that read from file + * \return Text data, store in std::string, joined all lines by delimiter + */ inline std::string JoinedLines(std::string delimiter = "\n") { std::stringstream ss; for (auto line : lines_) { @@ -494,47 +510,48 @@ class TextReader { INDEX_T total_cnt = 0; size_t bytes_read = 0; PipelineReader::Read(filename_, skip_bytes_, - [&process_fun, &bytes_read, &total_cnt, this] - (const char* buffer_process, size_t read_cnt) { - size_t cnt = 0; - size_t i = 0; - size_t last_i = 0; - // skip the break between \r and \n - if (last_line_.size() == 0 && buffer_process[0] == '\n') { - i = 1; - last_i = i; - } - while (i < read_cnt) { - if (buffer_process[i] == '\n' || buffer_process[i] == '\r') { - if (last_line_.size() > 0) { - last_line_.append(buffer_process + last_i, i - last_i); - process_fun(total_cnt, last_line_.c_str(), last_line_.size()); - last_line_ = ""; - } else { - process_fun(total_cnt, buffer_process + last_i, i - last_i); - } - ++cnt; - ++i; - ++total_cnt; - // skip end of line - while ((buffer_process[i] == '\n' || buffer_process[i] == '\r') && i < read_cnt) { ++i; } - last_i = i; - } else { - ++i; - } - } - if (last_i != read_cnt) { - last_line_.append(buffer_process + last_i, read_cnt - last_i); - } - - size_t prev_bytes_read = bytes_read; - bytes_read += read_cnt; - if (prev_bytes_read / read_progress_interval_bytes_ < bytes_read / read_progress_interval_bytes_) { - Log::Debug("Read %.1f GBs from %s.", 1.0 * bytes_read / kGbs, filename_); - } - - return cnt; - }); + [&process_fun, &bytes_read, &total_cnt, this](const char* buffer_process, size_t read_cnt) { + size_t cnt = 0; + size_t i = 0; + size_t last_i = 0; + // skip the break between \r and \n + if (last_line_.size() == 0 && buffer_process[0] == '\n') { + i = 1; + last_i = i; + } + while (i < read_cnt) { + if 
(buffer_process[i] == '\n' || buffer_process[i] == '\r') { + if (last_line_.size() > 0) { + last_line_.append(buffer_process + last_i, i - last_i); + process_fun(total_cnt, last_line_.c_str(), last_line_.size()); + last_line_ = ""; + } else { + process_fun(total_cnt, buffer_process + last_i, i - last_i); + } + ++cnt; + ++i; + ++total_cnt; + // skip end of line + while ((buffer_process[i] == '\n' || buffer_process[i] == '\r') && i < read_cnt) { + ++i; + } + last_i = i; + } else { + ++i; + } + } + if (last_i != read_cnt) { + last_line_.append(buffer_process + last_i, read_cnt - last_i); + } + + size_t prev_bytes_read = bytes_read; + bytes_read += read_cnt; + if (prev_bytes_read / read_progress_interval_bytes_ < bytes_read / read_progress_interval_bytes_) { + Log::Debug("Read %.1f GBs from %s.", 1.0 * bytes_read / kGbs, filename_); + } + + return cnt; + }); // if last line of file doesn't contain end of line if (last_line_.size() > 0) { Log::Info("Warning: last line of %s has no end of line, still using this line", filename_); @@ -546,14 +563,14 @@ class TextReader { } /*! 
- * \brief Read all text data from file in memory - * \return number of lines of text data - */ + * \brief Read all text data from file in memory + * \return number of lines of text data + */ INDEX_T ReadAllLines() { return ReadAllAndProcess( - [=](INDEX_T, const char* buffer, size_t size) { - lines_.emplace_back(buffer, size); - }); + [=](INDEX_T, const char* buffer, size_t size) { + lines_.emplace_back(buffer, size); + }); } std::vector ReadContent(size_t* out_len) { @@ -577,8 +594,7 @@ class TextReader { INDEX_T SampleFromFile(Random* random, INDEX_T sample_cnt, std::vector* out_sampled_data) { INDEX_T cur_sample_cnt = 0; return ReadAllAndProcess([=, &random, &cur_sample_cnt, - &out_sampled_data] - (INDEX_T line_idx, const char* buffer, size_t size) { + &out_sampled_data](INDEX_T line_idx, const char* buffer, size_t size) { if (cur_sample_cnt < sample_cnt) { out_sampled_data->emplace_back(buffer, size); ++cur_sample_cnt; @@ -591,54 +607,52 @@ class TextReader { }); } /*! - * \brief Read part of text data from file in memory, use filter_fun to filter data - * \param filter_fun Function that perform data filter - * \param out_used_data_indices Store line indices that read text data - * \return The number of total data - */ + * \brief Read part of text data from file in memory, use filter_fun to filter data + * \param filter_fun Function that perform data filter + * \param out_used_data_indices Store line indices that read text data + * \return The number of total data + */ INDEX_T ReadAndFilterLines(const std::function& filter_fun, std::vector* out_used_data_indices) { out_used_data_indices->clear(); INDEX_T total_cnt = ReadAllAndProcess( - [&filter_fun, &out_used_data_indices, this] - (INDEX_T line_idx , const char* buffer, size_t size) { - bool is_used = filter_fun(line_idx); - if (is_used) { - out_used_data_indices->push_back(line_idx); - lines_.emplace_back(buffer, size); - } - }); + [&filter_fun, &out_used_data_indices, this](INDEX_T line_idx, const char* 
buffer, size_t size) { + bool is_used = filter_fun(line_idx); + if (is_used) { + out_used_data_indices->push_back(line_idx); + lines_.emplace_back(buffer, size); + } + }); return total_cnt; } INDEX_T SampleAndFilterFromFile(const std::function& filter_fun, std::vector* out_used_data_indices, - Random* random, INDEX_T sample_cnt, std::vector* out_sampled_data) { + Random* random, INDEX_T sample_cnt, std::vector* out_sampled_data) { INDEX_T cur_sample_cnt = 0; out_used_data_indices->clear(); INDEX_T total_cnt = ReadAllAndProcess( [=, &filter_fun, &out_used_data_indices, &random, &cur_sample_cnt, - &out_sampled_data] - (INDEX_T line_idx, const char* buffer, size_t size) { - bool is_used = filter_fun(line_idx); - if (is_used) { - out_used_data_indices->push_back(line_idx); - if (cur_sample_cnt < sample_cnt) { - out_sampled_data->emplace_back(buffer, size); - ++cur_sample_cnt; - } else { - const size_t idx = static_cast(random->NextInt(0, static_cast(out_used_data_indices->size()))); - if (idx < static_cast(sample_cnt)) { - out_sampled_data->operator[](idx) = std::string(buffer, size); + &out_sampled_data](INDEX_T line_idx, const char* buffer, size_t size) { + bool is_used = filter_fun(line_idx); + if (is_used) { + out_used_data_indices->push_back(line_idx); + if (cur_sample_cnt < sample_cnt) { + out_sampled_data->emplace_back(buffer, size); + ++cur_sample_cnt; + } else { + const size_t idx = static_cast(random->NextInt(0, static_cast(out_used_data_indices->size()))); + if (idx < static_cast(sample_cnt)) { + out_sampled_data->operator[](idx) = std::string(buffer, size); + } + } } - } - } - }); + }); return total_cnt; } INDEX_T CountLine() { return ReadAllAndProcess( - [=](INDEX_T, const char*, size_t) { - }); + [=](INDEX_T, const char*, size_t) { + }); } INDEX_T ReadAllAndProcessParallelWithFilter(const std::function&)>& process_fun, const std::function& filter_fun) { @@ -647,56 +661,57 @@ class TextReader { size_t bytes_read = 0; INDEX_T used_cnt = 0; 
PipelineReader::Read(filename_, skip_bytes_, - [&process_fun, &filter_fun, &total_cnt, &bytes_read, &used_cnt, this] - (const char* buffer_process, size_t read_cnt) { - size_t cnt = 0; - size_t i = 0; - size_t last_i = 0; - INDEX_T start_idx = used_cnt; - // skip the break between \r and \n - if (last_line_.size() == 0 && buffer_process[0] == '\n') { - i = 1; - last_i = i; - } - while (i < read_cnt) { - if (buffer_process[i] == '\n' || buffer_process[i] == '\r') { - if (last_line_.size() > 0) { - last_line_.append(buffer_process + last_i, i - last_i); - if (filter_fun(used_cnt, total_cnt)) { - lines_.push_back(last_line_); - ++used_cnt; - } - last_line_ = ""; - } else { - if (filter_fun(used_cnt, total_cnt)) { - lines_.emplace_back(buffer_process + last_i, i - last_i); - ++used_cnt; - } - } - ++cnt; - ++i; - ++total_cnt; - // skip end of line - while ((buffer_process[i] == '\n' || buffer_process[i] == '\r') && i < read_cnt) { ++i; } - last_i = i; - } else { - ++i; - } - } - process_fun(start_idx, lines_); - lines_.clear(); - if (last_i != read_cnt) { - last_line_.append(buffer_process + last_i, read_cnt - last_i); - } - - size_t prev_bytes_read = bytes_read; - bytes_read += read_cnt; - if (prev_bytes_read / read_progress_interval_bytes_ < bytes_read / read_progress_interval_bytes_) { - Log::Debug("Read %.1f GBs from %s.", 1.0 * bytes_read / kGbs, filename_); - } - - return cnt; - }); + [&process_fun, &filter_fun, &total_cnt, &bytes_read, &used_cnt, this](const char* buffer_process, size_t read_cnt) { + size_t cnt = 0; + size_t i = 0; + size_t last_i = 0; + INDEX_T start_idx = used_cnt; + // skip the break between \r and \n + if (last_line_.size() == 0 && buffer_process[0] == '\n') { + i = 1; + last_i = i; + } + while (i < read_cnt) { + if (buffer_process[i] == '\n' || buffer_process[i] == '\r') { + if (last_line_.size() > 0) { + last_line_.append(buffer_process + last_i, i - last_i); + if (filter_fun(used_cnt, total_cnt)) { + lines_.push_back(last_line_); + 
++used_cnt; + } + last_line_ = ""; + } else { + if (filter_fun(used_cnt, total_cnt)) { + lines_.emplace_back(buffer_process + last_i, i - last_i); + ++used_cnt; + } + } + ++cnt; + ++i; + ++total_cnt; + // skip end of line + while ((buffer_process[i] == '\n' || buffer_process[i] == '\r') && i < read_cnt) { + ++i; + } + last_i = i; + } else { + ++i; + } + } + process_fun(start_idx, lines_); + lines_.clear(); + if (last_i != read_cnt) { + last_line_.append(buffer_process + last_i, read_cnt - last_i); + } + + size_t prev_bytes_read = bytes_read; + bytes_read += read_cnt; + if (prev_bytes_read / read_progress_interval_bytes_ < bytes_read / read_progress_interval_bytes_) { + Log::Debug("Read %.1f GBs from %s.", 1.0 * bytes_read / kGbs, filename_); + } + + return cnt; + }); // if last line of file doesn't contain end of line if (last_line_.size() > 0) { Log::Info("Warning: last line of %s has no end of line, still using this line", filename_); @@ -718,13 +733,13 @@ class TextReader { INDEX_T ReadPartAndProcessParallel(const std::vector& used_data_indices, const std::function&)>& process_fun) { return ReadAllAndProcessParallelWithFilter(process_fun, - [&used_data_indices](INDEX_T used_cnt, INDEX_T total_cnt) { - if (static_cast(used_cnt) < used_data_indices.size() && total_cnt == used_data_indices[used_cnt]) { - return true; - } else { - return false; - } - }); + [&used_data_indices](INDEX_T used_cnt, INDEX_T total_cnt) { + if (static_cast(used_cnt) < used_data_indices.size() && total_cnt == used_data_indices[used_cnt]) { + return true; + } else { + return false; + } + }); } private: @@ -745,4 +760,4 @@ class TextReader { } // namespace StochTree -#endif // STOCHTREE_IO_H_ +#endif // STOCHTREE_IO_H_ diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 6f42c110..b40156e6 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -383,8 +383,8 @@ class GaussianConstantSuffStat { void IncrementSuffStat(ForestDataset& 
dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { - sum_w += 1/dataset.VarWeightValue(row_idx); - sum_yw += outcome(row_idx, 0)/dataset.VarWeightValue(row_idx); + sum_w += 1 / dataset.VarWeightValue(row_idx); + sum_yw += outcome(row_idx, 0) / dataset.VarWeightValue(row_idx); } else { sum_w += 1.0; sum_yw += outcome(row_idx, 0); @@ -462,7 +462,10 @@ class GaussianConstantLeafModel { * * \param tau Leaf node prior scale parameter */ - GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianConstantLeafModel(double tau) { + tau_ = tau; + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianConstantLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. @@ -511,11 +514,12 @@ class GaussianConstantLeafModel { * * \param tau Leaf node prior scale parameter */ - void SetScale(double tau) {tau_ = tau;} + void SetScale(double tau) { tau_ = tau; } /*! 
* \brief Whether this model requires a basis vector for posterior inference and prediction */ - inline bool RequiresBasis() {return false;} + inline bool RequiresBasis() { return false; } + private: double tau_; UnivariateNormalSampler normal_sampler_; @@ -547,11 +551,11 @@ class GaussianUnivariateRegressionSuffStat { void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { - sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx); - sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx); + sum_xxw += dataset.BasisValue(row_idx, 0) * dataset.BasisValue(row_idx, 0) / dataset.VarWeightValue(row_idx); + sum_yxw += outcome(row_idx, 0) * dataset.BasisValue(row_idx, 0) / dataset.VarWeightValue(row_idx); } else { - sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0); - sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0); + sum_xxw += dataset.BasisValue(row_idx, 0) * dataset.BasisValue(row_idx, 0); + sum_yxw += outcome(row_idx, 0) * dataset.BasisValue(row_idx, 0); } } /*! @@ -562,7 +566,7 @@ class GaussianUnivariateRegressionSuffStat { sum_xxw = 0.0; sum_yxw = 0.0; } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics @@ -621,7 +625,10 @@ class GaussianUnivariateRegressionSuffStat { /*! 
\brief Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model */ class GaussianUnivariateRegressionLeafModel { public: - GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianUnivariateRegressionLeafModel(double tau) { + tau_ = tau; + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianUnivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. @@ -665,8 +672,9 @@ class GaussianUnivariateRegressionLeafModel { */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); - void SetScale(double tau) {tau_ = tau;} - inline bool RequiresBasis() {return true;} + void SetScale(double tau) { tau_ = tau; } + inline bool RequiresBasis() { return true; } + private: double tau_; UnivariateNormalSampler normal_sampler_; @@ -703,16 +711,16 @@ class GaussianMultivariateRegressionSuffStat { n += 1; if (dataset.HasVarWeights()) { for (int i = 0; i < p; i++) { - ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i) / dataset.VarWeightValue(row_idx); + ytWX(0, i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i) / dataset.VarWeightValue(row_idx); for (int j = 0; j < p; j++) { - XtWX(i,j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j) / dataset.VarWeightValue(row_idx); + XtWX(i, j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j) / dataset.VarWeightValue(row_idx); } } } else { for (int i = 0; i < p; i++) { - ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i); + ytWX(0, i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i); for (int j = 0; j < p; j++) { - XtWX(i,j) += 
dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j); + XtWX(i, j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j); } } } @@ -729,7 +737,7 @@ class GaussianMultivariateRegressionSuffStat { } } } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics @@ -793,7 +801,10 @@ class GaussianMultivariateRegressionLeafModel { * * \param Sigma_0 Prior covariance, must have the same number of rows and columns as dimensions of the basis vector for the multivariate regression problem */ - GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();} + GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) { + Sigma_0_ = Sigma_0; + multivariate_normal_sampler_ = MultivariateNormalSampler(); + } ~GaussianMultivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. 
@@ -837,8 +848,9 @@ class GaussianMultivariateRegressionLeafModel { */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); - void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;} - inline bool RequiresBasis() {return true;} + void SetScale(Eigen::MatrixXd& Sigma_0) { Sigma_0_ = Sigma_0; } + inline bool RequiresBasis() { return true; } + private: Eigen::MatrixXd Sigma_0_; MultivariateNormalSampler multivariate_normal_sampler_; @@ -864,7 +876,7 @@ class LogLinearVarianceSuffStat { */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; - weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx)); + weighted_sum_ei += std::exp(std::log(outcome(row_idx) * outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx)); } /*! * \brief Reset all of the sufficient statistics to zero @@ -929,7 +941,11 @@ class LogLinearVarianceSuffStat { /*! \brief Marginal likelihood and posterior computation for heteroskedastic log-linear variance model */ class LogLinearVarianceLeafModel { public: - LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();} + LogLinearVarianceLeafModel(double a, double b) { + a_ = a; + b_ = b; + gamma_sampler_ = GammaSampler(); + } ~LogLinearVarianceLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. 
@@ -974,16 +990,16 @@ class LogLinearVarianceLeafModel { */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); - void SetPriorShape(double a) {a_ = a;} - void SetPriorRate(double b) {b_ = b;} - inline bool RequiresBasis() {return false;} + void SetPriorShape(double a) { a_ = a; } + void SetPriorRate(double b) { b_ = b; } + inline bool RequiresBasis() { return false; } + private: double a_; double b_; GammaSampler gamma_sampler_; }; - /*! \brief Sufficient statistic and associated operations for complementary log-log ordinal BART model */ class CloglogOrdinalSuffStat { public: @@ -1016,20 +1032,20 @@ class CloglogOrdinalSuffStat { unsigned int y = static_cast(outcome(row_idx)); // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma) - double Z = dataset.GetAuxiliaryDataValue(0, row_idx); // latent variables Z + double Z = dataset.GetAuxiliaryDataValue(0, row_idx); // latent variables Z double lambda_minus = dataset.GetAuxiliaryDataValue(1, row_idx); // forest predictions excluding current tree // Get cutpoints gamma and cumulative sum of exp(gamma) const std::vector& gamma = dataset.GetAuxiliaryDataVectorConst(2); // cutpoints gamma const std::vector& seg = dataset.GetAuxiliaryDataVectorConst(3); // cumsum exp of gamma - + int K = gamma.size() + 1; // Number of ordinal categories if (y == K - 1) { other_sum += std::exp(lambda_minus) * seg[y]; // checked and it's correct } else { sum_Y_less_K += 1.0; - other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]); // checked and it's correct + other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]); // checked and it's correct } } @@ -1150,7 +1166,7 @@ class CloglogOrdinalLeafModel { * Samples from 
log-gamma: sample from gamma, then take log. */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); - inline bool RequiresBasis() {return false;} + inline bool RequiresBasis() { return false; } private: double a_; @@ -1185,12 +1201,12 @@ using LeafModelVariant = std::variant; -template +template static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) { return SuffStatType(leaf_suff_stat_args...); } -template +template static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) { return LeafModelType(leaf_model_args...); } @@ -1238,12 +1254,11 @@ static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau } } -template +template static inline void AccumulateSuffStatProposed( - SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, - ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, - SuffStatConstructorArgs&... suff_stat_args -) { + SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, + ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, + SuffStatConstructorArgs&... 
suff_stat_args) { // Determine the position of the node's indices in the forest tracking data structure int node_begin_index = tracker.UnsortedNodeBegin(tree_num, leaf_num); int node_end_index = tracker.UnsortedNodeEnd(tree_num, leaf_num); @@ -1309,9 +1324,9 @@ static inline void AccumulateSuffStatProposed( } } -template +template static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, - ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) { + ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) { // Acquire iterators auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id); auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id); @@ -1333,7 +1348,7 @@ static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, Suff } } -template +template static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) { // Acquire iterators std::vector::iterator node_begin_iter; @@ -1354,10 +1369,10 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo } } -template +template static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, - ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, - int feature_num, int cutpoint_num) { + ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, + int feature_num, int cutpoint_num) { // Acquire iterators auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num); auto node_end_iter = 
tracker.SortedNodeEndIterator(node_id, feature_num); @@ -1382,8 +1397,8 @@ static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, F } } -/*! \} */ // end of leaf_model_group +/*! \} */ // end of leaf_model_group -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_LEAF_MODEL_H_ +#endif // STOCHTREE_LEAF_MODEL_H_ diff --git a/include/stochtree/log.h b/include/stochtree/log.h index 9f64c31b..8ce87f79 100644 --- a/include/stochtree/log.h +++ b/include/stochtree/log.h @@ -1,9 +1,9 @@ /*! * Logging and runtime value checking utilities. - * - * This code is largely included as-is from LightGBM, which carries + * + * This code is largely included as-is from LightGBM, which carries * the following copyright information: - * + * * Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for * license information. @@ -65,10 +65,10 @@ namespace StochTree { #endif #ifndef CHECK_NOTNULL -#define CHECK_NOTNULL(pointer) \ - if ((pointer) == nullptr) \ +#define CHECK_NOTNULL(pointer) \ + if ((pointer) == nullptr) \ StochTree::Log::Fatal(#pointer " Can't be NULL at %s, line %d .\n", \ - __FILE__, __LINE__); + __FILE__, __LINE__); #endif enum class LogLevel : int { @@ -83,7 +83,7 @@ enum class LogLevel : int { */ class Log { public: - using Callback = void (*)(const char *); + using Callback = void (*)(const char*); /*! * \brief Resets the minimal log level. It is INFO by default. * \param level The new minimal log level. @@ -92,25 +92,25 @@ class Log { static void ResetCallBack(Callback callback) { GetLogCallBack() = callback; } - static void Debug(const char *format, ...) { + static void Debug(const char* format, ...) { va_list val; va_start(val, format); Write(LogLevel::Debug, "Debug", format, val); va_end(val); } - static void Info(const char *format, ...) { + static void Info(const char* format, ...) 
{ va_list val; va_start(val, format); Write(LogLevel::Info, "Info", format, val); va_end(val); } - static void Warning(const char *format, ...) { + static void Warning(const char* format, ...) { va_list val; va_start(val, format); Write(LogLevel::Warning, "Warning", format, val); va_end(val); } - static void Fatal(const char *format, ...) { + static void Fatal(const char* format, ...) { va_list val; const size_t kBufSize = 1024; char str_buf[kBufSize]; @@ -135,7 +135,7 @@ class Log { } private: - static void Write(LogLevel level, const char *level_str, const char *format, + static void Write(LogLevel level, const char* level_str, const char* format, va_list val) { if (level <= GetLevel()) { // omit the message with low level // R code should write back to R's output stream, @@ -166,12 +166,12 @@ class Log { // a trick to use static variable in header file. // May be not good, but avoid to use an additional cpp file - static LogLevel &GetLevel() { + static LogLevel& GetLevel() { static THREAD_LOCAL LogLevel level = LogLevel::Info; return level; } - static Callback &GetLogCallBack() { + static Callback& GetLogCallBack() { static THREAD_LOCAL Callback callback = nullptr; return callback; } diff --git a/include/stochtree/mainpage.h b/include/stochtree/mainpage.h index dc39f162..09cfdc28 100644 --- a/include/stochtree/mainpage.h +++ b/include/stochtree/mainpage.h @@ -3,80 +3,80 @@ /*! * \mainpage stochtree C++ Documentation - * + * * \section getting-started Getting Started - * + * * `stochtree` can be built and run as a standalone C++ program directly from source using `cmake`: - * + * * \subsection cloning-repo Cloning the Repository - * - * To clone the repository, you must have git installed, which you can do following these instructions. - * - * Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running `cd` followed by the path to the directory). 
+ * + * To clone the repository, you must have git installed, which you can do following these instructions. + * + * Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running `cd` followed by the path to the directory). * Then, clone the `stochtree` repo as a subfolder by running * \code{.sh} * git clone --recursive https://github.com/StochasticTree/stochtree.git * \endcode - * - * NOTE: this project incorporates several dependencies as git submodules, - * which is why the `--recursive` flag is necessary (some systems may perform a recursive clone without this flag, but - * `--recursive` ensures this behavior on all platforms). If you have already cloned the repo without the `--recursive` flag, + * + * NOTE: this project incorporates several dependencies as git submodules, + * which is why the `--recursive` flag is necessary (some systems may perform a recursive clone without this flag, but + * `--recursive` ensures this behavior on all platforms). If you have already cloned the repo without the `--recursive` flag, * you can retrieve the submodules recursively by running `git submodule update --init --recursive` in the main repo directory. - * + * * \section key-components Key Components - * + * * The stochtree C++ core consists of thousands of lines of C++ code, but it can organized and understood through several components (see [topics](topics.html) for more detail): - * + * * - Trees: the most important "primitive" of decision tree algorithms is the \ref tree_group "decision tree itself", which in stochtree is defined by a \ref StochTree::Tree "Tree" class as well as a series of static helper functions for prediction. 
* - Forest: individual trees are combined into a \ref forest_group "forest", or ensemble, which in stochtree is defined by the \ref StochTree::TreeEnsemble "TreeEnsemble" class and a container of forests is defined by the \ref StochTree::ForestContainer "ForestContainer" class. * - Dataset: data can be loaded from a variety of sources into a `stochtree` \ref data_group "data layer". * - Leaf Model: `stochtree`'s data structures are generalized to support a wide range of models, which are defined via specialized classes in the \ref leaf_model_group "leaf model layer". * - Sampler: helper functions that sample forests from training data comprise the \ref sampling_group "sampling layer" of `stochtree`. - * + * * \section extending-stochtree Extending stochtree - * + * * \subsection custom-leaf-models Custom Leaf Models - * - * The \ref leaf_model_group "leaf model documentation" details the key components of new decision tree models: - * custom `LeafModel` and `SuffStat` classes that implement a model's log marginal likelihood and posterior computations. - * - * Adding a new leaf model will consist largely of implementing new versions of each of these classes which track the - * API of the existing classes. Once these classes exist, they need to be reflected in several places. - * + * + * The \ref leaf_model_group "leaf model documentation" details the key components of new decision tree models: + * custom `LeafModel` and `SuffStat` classes that implement a model's log marginal likelihood and posterior computations. + * + * Adding a new leaf model will consist largely of implementing new versions of each of these classes which track the + * API of the existing classes. Once these classes exist, they need to be reflected in several places. + * * Suppose, for the sake of illustration, that the newest custom leaf model is a multinomial logit model. 
- * + * * First, add an entry to the \ref StochTree::ModelType "ModelType" enumeration for this new model type - * + * * \code{.cpp} * enum ModelType { - * kConstantLeafGaussian, - * kUnivariateRegressionLeafGaussian, - * kMultivariateRegressionLeafGaussian, - * kLogLinearVariance, - * kMultinomialLogit, + * kConstantLeafGaussian, + * kUnivariateRegressionLeafGaussian, + * kMultivariateRegressionLeafGaussian, + * kLogLinearVariance, + * kMultinomialLogit, * }; - * \endcode - * + * \endcode + * * Next, add entries to the `std::variants` that bundle related `SuffStat` and `LeafModel` classes - * + * * \code{.cpp} - * using SuffStatVariant = std::variant; - * \endcode - * + * \endcode + * * \code{.cpp} - * using LeafModelVariant = std::variant; - * \endcode - * + * \endcode + * * Finally, update the \ref StochTree::suffStatFactory "suffStatFactory" and \ref StochTree::leafModelFactory "leafModelFactory" functions to add a logic branch registering these new objects - * + * * \code{.cpp} * static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) { * if (model_type == kConstantLeafGaussian) { @@ -93,8 +93,8 @@ * Log::Fatal("Incompatible model type provided to suff stat factory"); * } * } - * \endcode - * + * \endcode + * * \code{.cpp} * static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { * if (model_type == kConstantLeafGaussian) { @@ -111,8 +111,8 @@ * Log::Fatal("Incompatible model type provided to leaf model factory"); * } * } - * \endcode - * + * \endcode + * */ #endif // STOCHTREE_MAINPAGE_H_ diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index d0aa4049..1a8edc78 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -1,9 +1,9 @@ /*! 
* Macros, constants, and type definitions used elsewhere in the codebase - * - * This code is largely included as-is from LightGBM, which carries + * + * This code is largely included as-is from LightGBM, which carries * the following copyright information: - * + * * Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ @@ -20,20 +20,22 @@ #include #if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER) || MM_PREFETCH - #include - #define PREFETCH_T0(addr) _mm_prefetch(reinterpret_cast(addr), _MM_HINT_T0) +#include +#define PREFETCH_T0(addr) _mm_prefetch(reinterpret_cast(addr), _MM_HINT_T0) #elif defined(__GNUC__) - #define PREFETCH_T0(addr) __builtin_prefetch(reinterpret_cast(addr), 0, 3) +#define PREFETCH_T0(addr) __builtin_prefetch(reinterpret_cast(addr), 0, 3) #else - #define PREFETCH_T0(addr) do {} while (0) +#define PREFETCH_T0(addr) \ + do { \ + } while (0) #endif namespace StochTree { /*! 
\brief Integer encoding of feature types */ enum FeatureType { - kNumeric, /*!< Numeric feature */ - kOrderedCategorical, /*!< Ordered categorical feature */ + kNumeric, /*!< Numeric feature */ + kOrderedCategorical, /*!< Ordered categorical feature */ kUnorderedCategorical /*!< Unordered categorical feature */ }; @@ -123,31 +125,29 @@ typedef int32_t node_t; typedef double split_cond_t; using PredictFunction = -std::function>&, double* output)>; + std::function>&, double* output)>; using PredictSparseFunction = -std::function>&, std::vector>* output)>; + std::function>&, std::vector>* output)>; -typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size); +typedef void (*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size); +typedef void (*ReduceScatterFunction)(char* input, comm_size_t input_size, int type_size, + const comm_size_t* block_start, const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size, + const ReduceFunction& reducer); -typedef void(*ReduceScatterFunction)(char* input, comm_size_t input_size, int type_size, - const comm_size_t* block_start, const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size, - const ReduceFunction& reducer); - -typedef void(*AllgatherFunction)(char* input, comm_size_t input_size, const comm_size_t* block_start, - const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size); - +typedef void (*AllgatherFunction)(char* input, comm_size_t input_size, const comm_size_t* block_start, + const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size); #define NO_SPECIFIC (-1) const int kAlignedSize = 32; -#define SIZE_ALIGNED(t) ((t) + kAlignedSize - 1) / kAlignedSize * kAlignedSize +#define SIZE_ALIGNED(t) ((t) + kAlignedSize - 1) / kAlignedSize* kAlignedSize // Refer to 
https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-4-c4127?view=vs-2019 #ifdef _MSC_VER - #pragma warning(disable : 4127) +#pragma warning(disable : 4127) #endif } // namespace StochTree diff --git a/include/stochtree/normal_sampler.h b/include/stochtree/normal_sampler.h index 60b8a550..bca02079 100644 --- a/include/stochtree/normal_sampler.h +++ b/include/stochtree/normal_sampler.h @@ -12,11 +12,12 @@ namespace StochTree { class UnivariateNormalSampler { public: - UnivariateNormalSampler() {std_normal_dist_ = standard_normal();} + UnivariateNormalSampler() { std_normal_dist_ = standard_normal(); } ~UnivariateNormalSampler() {} double Sample(double mean, double variance, std::mt19937& gen) { return mean + std::sqrt(variance) * std_normal_dist_(gen); } + private: /*! \brief Standard normal distribution */ standard_normal std_normal_dist_; @@ -24,7 +25,7 @@ class UnivariateNormalSampler { class MultivariateNormalSampler { public: - MultivariateNormalSampler() {std_normal_dist_ = standard_normal();} + MultivariateNormalSampler() { std_normal_dist_ = standard_normal(); } ~MultivariateNormalSampler() {} std::vector Sample(Eigen::VectorXd& mean, Eigen::MatrixXd& covariance, std::mt19937& gen) { // Dimension extraction and checks @@ -32,7 +33,7 @@ class MultivariateNormalSampler { int cov_rows = covariance.rows(); int cov_cols = covariance.cols(); CHECK_EQ(mean_cols, cov_cols); - + // Variance cholesky decomposition Eigen::LLT decomposition(covariance); Eigen::MatrixXd covariance_chol = decomposition.matrixL(); @@ -57,7 +58,7 @@ class MultivariateNormalSampler { int cov_rows = covariance.rows(); int cov_cols = covariance.cols(); CHECK_EQ(mean_cols, cov_cols); - + // Variance cholesky decomposition Eigen::LLT decomposition(covariance); Eigen::MatrixXd covariance_chol = decomposition.matrixL(); @@ -71,11 +72,12 @@ class MultivariateNormalSampler { // Compute and return the sampled value return mean + covariance_chol * std_norm_vec; 
} + private: /*! \brief Standard normal distribution */ standard_normal std_normal_dist_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_NORMAL_SAMPLER_H_ \ No newline at end of file +#endif // STOCHTREE_NORMAL_SAMPLER_H_ \ No newline at end of file diff --git a/include/stochtree/openmp_utils.h b/include/stochtree/openmp_utils.h index 78c17234..e3add212 100644 --- a/include/stochtree/openmp_utils.h +++ b/include/stochtree/openmp_utils.h @@ -13,43 +13,43 @@ namespace StochTree { // OpenMP thread management inline int get_max_threads() { - return omp_get_max_threads(); + return omp_get_max_threads(); } inline int get_thread_num() { - return omp_get_thread_num(); + return omp_get_thread_num(); } inline int get_num_threads() { - return omp_get_num_threads(); + return omp_get_num_threads(); } inline void set_num_threads(int num_threads) { - omp_set_num_threads(num_threads); + omp_set_num_threads(num_threads); } - + #define STOCHTREE_PARALLEL_FOR(num_threads) \ - _Pragma("omp parallel for num_threads(num_threads)") + _Pragma("omp parallel for num_threads(num_threads)") #define STOCHTREE_REDUCTION_ADD(var) \ - _Pragma("omp reduction(+:var)") + _Pragma("omp reduction(+:var)") #define STOCHTREE_CRITICAL \ - _Pragma("omp critical") + _Pragma("omp critical") #else #define STOCHTREE_HAS_OPENMP 0 -inline int get_max_threads() {return 1;} +inline int get_max_threads() { return 1; } -inline int get_thread_num() {return 0;} +inline int get_thread_num() { return 0; } -inline int get_num_threads() {return 1;} +inline int get_num_threads() { return 1; } inline void set_num_threads(int num_threads) {} - + #define STOCHTREE_PARALLEL_FOR(num_threads) - + #define STOCHTREE_REDUCTION_ADD(var) #define STOCHTREE_CRITICAL @@ -57,57 +57,57 @@ inline void set_num_threads(int num_threads) {} #endif static int GetMaxThreads() { - return get_max_threads(); + return get_max_threads(); } static int GetCurrentThreadNum() { - return get_thread_num(); + return 
get_thread_num(); } - + static int GetNumThreads() { - return get_num_threads(); + return get_num_threads(); } - + static void SetNumThreads(int num_threads) { - set_num_threads(num_threads); + set_num_threads(num_threads); } - + static bool IsOpenMPAvailable() { - return STOCHTREE_HAS_OPENMP; + return STOCHTREE_HAS_OPENMP; } - + static int GetOptimalThreadCount(int workload_size, int min_work_per_thread = 1000) { - if (!IsOpenMPAvailable()) { - return 1; - } - - int max_threads = GetMaxThreads(); - int optimal_threads = workload_size / min_work_per_thread; - - return std::min(optimal_threads, max_threads); + if (!IsOpenMPAvailable()) { + return 1; + } + + int max_threads = GetMaxThreads(); + int optimal_threads = workload_size / min_work_per_thread; + + return std::min(optimal_threads, max_threads); } // Parallel execution utilities -template +template void ParallelFor(int start, int end, int num_threads, Func func) { - if (num_threads <= 0) { - num_threads = GetOptimalThreadCount(end - start); + if (num_threads <= 0) { + num_threads = GetOptimalThreadCount(end - start); + } + + if (num_threads == 1 || !STOCHTREE_HAS_OPENMP) { + // Sequential execution + for (int i = start; i < end; ++i) { + func(i); } - - if (num_threads == 1 || !STOCHTREE_HAS_OPENMP) { - // Sequential execution - for (int i = start; i < end; ++i) { - func(i); - } - } else { - // Parallel execution - STOCHTREE_PARALLEL_FOR(num_threads) - for (int i = start; i < end; ++i) { - func(i); - } + } else { + // Parallel execution + STOCHTREE_PARALLEL_FOR(num_threads) + for (int i = start; i < end; ++i) { + func(i); } + } } -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_OPENMP_UTILS_H \ No newline at end of file +#endif // STOCHTREE_OPENMP_UTILS_H \ No newline at end of file diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index d67563e2..b4269987 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ 
-18,24 +18,24 @@ namespace StochTree { static double sample_truncated_exponential_low_high(double u, double rate, double low, double high) { - return -std::log((1-u)*std::exp(-rate*low) + u*std::exp(-rate*high))/rate; + return -std::log((1 - u) * std::exp(-rate * low) + u * std::exp(-rate * high)) / rate; } static double sample_truncated_exponential_low(double u, double rate, double low) { - return -std::log((1-u)*std::exp(-rate*low))/rate; + return -std::log((1 - u) * std::exp(-rate * low)) / rate; } static double sample_truncated_exponential_high(double u, double rate, double high) { - return -std::log1p(u*std::expm1(-high*rate))/rate; + return -std::log1p(u * std::expm1(-high * rate)) / rate; } static double sample_exponential(double u, double rate) { - return -std::log1p(-u)/rate; + return -std::log1p(-u) / rate; } /*! * \brief Sampler for ordinal model hyperparameters - * + * * This class handles MCMC sampling for ordinal-specific parameters: * - Truncated exponential latent variables (Z) * - Cutpoint parameters (gamma) @@ -50,9 +50,9 @@ class OrdinalSampler { /*! * \brief Sample from truncated exponential distribution - * + * * Samples from exponential distribution truncated to [low,high] - * + * * \param gen Random number generator * \param rate Rate parameter for exponential distribution * \param low Lower truncation bound @@ -63,7 +63,7 @@ class OrdinalSampler { /*! * \brief Update truncated exponential latent variables (Z) - * + * * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling * \param outcome Vector of outcome values * \param gen Random number generator @@ -72,7 +72,7 @@ class OrdinalSampler { /*! 
* \brief Update gamma cutpoint parameters - * + * * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling * \param outcome Vector of outcome values * \param alpha_gamma Shape parameter for log-gamma prior on cutpoints gamma @@ -80,13 +80,13 @@ class OrdinalSampler { * \param gamma_0 Fixed value for first cutpoint parameter (for identifiability) * \param gen Random number generator */ - void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, - double alpha_gamma, double beta_gamma, + void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, + double alpha_gamma, double beta_gamma, double gamma_0, std::mt19937& gen); /*! * \brief Update cumulative exponential sums (seg) - * + * * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling */ void UpdateCumulativeExpSums(ForestDataset& dataset); @@ -95,6 +95,6 @@ class OrdinalSampler { GammaSampler gamma_sampler_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_ORDINAL_SAMPLER_H_ +#endif // STOCHTREE_ORDINAL_SAMPLER_H_ diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index f25c875c..a0247884 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -1,25 +1,25 @@ /*! * Copyright (c) 2024 stochtree authors. - * + * * Data structures used for tracking dataset through the tree building process. - * + * * The first category of data structure tracks observations available in nodes of a tree. - * a. UnsortedNodeSampleTracker tracks the observations available in every leaf of every tree in an ensemble, + * a. UnsortedNodeSampleTracker tracks the observations available in every leaf of every tree in an ensemble, * in no feature-specific sort order. It is primarily designed for use in BART-based algorithms. - * b. 
SortedNodeSampleTracker tracks the observations available in a every leaf of a tree, pre-sorted + * b. SortedNodeSampleTracker tracks the observations available in a every leaf of a tree, pre-sorted * separately for each feature. It is primarily designed for use in XBART-based algorithms. - * + * * The second category, SampleNodeMapper, maps observations from a dataset to leaf nodes. - * - * SampleNodeMapper is inspired by the design of the DataPartition class in LightGBM, + * + * SampleNodeMapper is inspired by the design of the DataPartition class in LightGBM, * released under the MIT license with the following copyright: - * + * * Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. - * - * SortedNodeSampleTracker is inspired by the "approximate" split finding method in xgboost, released + * + * SortedNodeSampleTracker is inspired by the "approximate" split finding method in xgboost, released * under the Apache license with the following copyright: - * + * * Copyright 2015~2023 by XGBoost Contributors */ #ifndef STOCHTREE_PARTITION_TRACKER_H_ @@ -47,7 +47,7 @@ class ForestTracker { public: /*! * \brief Construct a new `ForestTracker` object - * + * * \param covariates Matrix of covariate data * \param feature_types Type of each feature (column) in `covariates`. 
This is represented by the enum `StochTree::FeatureType` * \param num_trees Number of trees in an ensemble to be sampled @@ -83,14 +83,15 @@ class ForestTracker { std::vector::iterator UnsortedNodeEndIterator(int tree_id, int node_id); std::vector::iterator SortedNodeBeginIterator(int node_id, int feature_id); std::vector::iterator SortedNodeEndIterator(int node_id, int feature_id); - SamplePredMapper* GetSamplePredMapper() {return sample_pred_mapper_.get();} - SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();} - UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();} - SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();} - int GetNumObservations() {return num_observations_;} - int GetNumTrees() {return num_trees_;} - int GetNumFeatures() {return num_features_;} - bool Initialized() {return initialized_;} + SamplePredMapper* GetSamplePredMapper() { return sample_pred_mapper_.get(); } + SampleNodeMapper* GetSampleNodeMapper() { return sample_node_mapper_.get(); } + UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() { return unsorted_node_sample_tracker_.get(); } + SortedNodeSampleTracker* GetSortedNodeSampleTracker() { return sorted_node_sample_tracker_.get(); } + int GetNumObservations() { return num_observations_; } + int GetNumTrees() { return num_trees_; } + int GetNumFeatures() { return num_features_; } + bool Initialized() { return initialized_; } + private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ std::vector sum_predictions_; @@ -102,7 +103,7 @@ class ForestTracker { * Primarily used in MCMC algorithms */ std::unique_ptr unsorted_node_sample_tracker_; - /*! \brief Data structure tracking / updating observations available in each node for each feature (pre-sorted) for a given tree in a forest + /*! 
\brief Data structure tracking / updating observations available in each node for each feature (pre-sorted) for a given tree in a forest * Primarily used in GFR algorithms */ std::unique_ptr presort_container_; @@ -145,10 +146,10 @@ class SamplePredMapper { CHECK_LT(tree_id, num_trees_); tree_preds_[tree_id][sample_id] = value; } - - inline int NumTrees() {return num_trees_;} - - inline int NumObservations() {return num_observations_;} + + inline int NumTrees() { return num_trees_; } + + inline int NumObservations() { return num_observations_; } inline void AssignAllSamplesToConstantPrediction(int tree_id, double value) { for (data_size_t i = 0; i < num_observations_; i++) { @@ -174,8 +175,8 @@ class SampleNodeMapper { tree_observation_indices_[j].resize(num_observations_); } } - - SampleNodeMapper(SampleNodeMapper& other){ + + SampleNodeMapper(SampleNodeMapper& other) { num_trees_ = other.NumTrees(); num_observations_ = other.NumObservations(); // Initialize the vector of vectors of leaf indices for each tree @@ -214,10 +215,10 @@ class SampleNodeMapper { CHECK_LT(tree_id, num_trees_); tree_observation_indices_[tree_id][sample_id] = node_id; } - - inline int NumTrees() {return num_trees_;} - - inline int NumObservations() {return num_observations_;} + + inline int NumTrees() { return num_trees_; } + + inline int NumObservations() { return num_observations_; } inline void AssignAllSamplesToRoot(int tree_id) { for (data_size_t i = 0; i < num_observations_; i++) { @@ -333,10 +334,11 @@ class UnsortedNodeSampleTracker { void PartitionTreeNode(Eigen::MatrixXd& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list) { return feature_partitions_[tree_id]->PartitionNode(covariates, node_id, left_node_id, right_node_id, feature_split, category_list); } - + /*! 
\brief Convert a tree to root */ void ResetTreeToRoot(int tree_id, data_size_t n) { - feature_partitions_[tree_id].reset(new FeatureUnsortedPartition(n));; + feature_partitions_[tree_id].reset(new FeatureUnsortedPartition(n)); + ; } /*! \brief Convert a (currently split) node to a leaf */ @@ -447,15 +449,15 @@ class NodeOffsetSize { ~NodeOffsetSize() {} - void SetSorted() {presorted_ = true;} + void SetSorted() { presorted_ = true; } - bool IsSorted() {return presorted_;} + bool IsSorted() { return presorted_; } - data_size_t Begin() {return node_begin_;} + data_size_t Begin() { return node_begin_; } - data_size_t End() {return node_end_;} + data_size_t End() { return node_end_; } - data_size_t Size() {return node_size_;} + data_size_t Size() { return node_size_; } private: data_size_t node_begin_; @@ -468,16 +470,17 @@ class NodeOffsetSize { class FeaturePresortPartition; /*! \brief Data structure for presorting a feature by its values - * - * This class is intended to be run *once* on a dataset as it + * + * This class is intended to be run *once* on a dataset as it * pre-sorts each feature across the entire dataset. - * + * * FeaturePresortPartition is intended for use in recursive construction - * of new trees, and each new tree's FeaturePresortPartition is initialized + * of new trees, and each new tree's FeaturePresortPartition is initialized * from a FeaturePresortRoot class so that features are only arg-sorted one time. 
*/ class FeaturePresortRoot { - friend FeaturePresortPartition; + friend FeaturePresortPartition; + public: FeaturePresortRoot(Eigen::MatrixXd& covariates, int32_t feature_index, FeatureType feature_type) { feature_index_ = feature_index; @@ -488,17 +491,17 @@ class FeaturePresortRoot { void ArgsortRoot(Eigen::MatrixXd& covariates) { data_size_t num_obs = covariates.rows(); - + // Make a vector of indices from 0 to num_obs - 1 - if (feature_sort_indices_.size() != num_obs){ + if (feature_sort_indices_.size() != num_obs) { feature_sort_indices_.resize(num_obs, 0); } std::iota(feature_sort_indices_.begin(), feature_sort_indices_.end(), 0); // Define a custom comparator to be used with stable_sort: - // For every two indices l and r store as elements of `data_sort_indices_`, + // For every two indices l and r store as elements of `data_sort_indices_`, // compare them for sorting purposes by indexing the covariate's raw data with both l and r - auto comp_op = [&](size_t const &l, size_t const &r) { return std::less{}(covariates(l, feature_index_), covariates(r, feature_index_)); }; + auto comp_op = [&](size_t const& l, size_t const& r) { return std::less{}(covariates(l, feature_index_), covariates(r, feature_index_)); }; std::stable_sort(feature_sort_indices_.begin(), feature_sort_indices_.end(), comp_op); } @@ -520,21 +523,21 @@ class FeaturePresortRootContainer { ~FeaturePresortRootContainer() {} - FeaturePresortRoot* GetFeaturePresort(int feature_num) {return feature_presort_[feature_num].get(); } + FeaturePresortRoot* GetFeaturePresort(int feature_num) { return feature_presort_[feature_num].get(); } private: std::vector> feature_presort_; int num_features_; }; -/*! \brief Data structure that tracks pre-sorted feature values +/*! 
\brief Data structure that tracks pre-sorted feature values * through a tree's split lifecycle - * - * This class is initialized from a FeaturePresortRoot which has computed the - * sort indices for a given feature over the entire dataset, so that sorting + * + * This class is initialized from a FeaturePresortRoot which has computed the + * sort indices for a given feature over the entire dataset, so that sorting * is not necessary for each new tree. - * - * When a split is made, this class handles sifting for each feature, so that + * + * When a split is made, this class handles sifting for each feature, so that * the presorted feature values available at each node are easily queried. */ class FeaturePresortPartition { @@ -563,28 +566,29 @@ class FeaturePresortPartition { void SplitFeatureCategorical(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, std::vector const& category_list); /*! \brief Start position of node indexed by node_id */ - data_size_t NodeBegin(int32_t node_id) {return node_offset_sizes_[node_id].Begin();} + data_size_t NodeBegin(int32_t node_id) { return node_offset_sizes_[node_id].Begin(); } /*! \brief End position of node indexed by node_id */ - data_size_t NodeEnd(int32_t node_id) {return node_offset_sizes_[node_id].End();} + data_size_t NodeEnd(int32_t node_id) { return node_offset_sizes_[node_id].End(); } /*! \brief Size (in observations) of node indexed by node_id */ - data_size_t NodeSize(int32_t node_id) {return node_offset_sizes_[node_id].Size();} + data_size_t NodeSize(int32_t node_id) { return node_offset_sizes_[node_id].Size(); } /*! \brief Data indices for a given node */ std::vector NodeIndices(int node_id); /*! \brief Feature sort index j */ - data_size_t SortIndex(data_size_t j) {return feature_sort_indices_[j];} + data_size_t SortIndex(data_size_t j) { return feature_sort_indices_[j]; } /*! 
\brief Feature type */ - FeatureType GetFeatureType() {return feature_type_;} + FeatureType GetFeatureType() { return feature_type_; } /*! \brief Update SampleNodeMapper for all the observations in node_id */ void UpdateObservationMapping(int node_id, int tree_id, SampleNodeMapper* sample_node_mapper); /*! \brief Feature sort indices */ std::vector feature_sort_indices_; + private: /*! \brief Add left and right nodes */ void AddLeftRightNodes(data_size_t left_node_begin, data_size_t left_node_size, data_size_t right_node_begin, data_size_t right_node_size); @@ -663,7 +667,7 @@ class SortedNodeSampleTracker { } /*! \brief Feature sort index j for feature_index */ - data_size_t SortIndex(data_size_t j, int feature_index) {return feature_partitions_[feature_index]->SortIndex(j); } + data_size_t SortIndex(data_size_t j, int feature_index) { return feature_partitions_[feature_index]->SortIndex(j); } /*! \brief Update SampleNodeMapper for all the observations in node_id */ void UpdateObservationMapping(int node_id, int tree_id, SampleNodeMapper* sample_node_mapper, int feature_index = 0) { @@ -675,6 +679,6 @@ class SortedNodeSampleTracker { int num_features_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_PARTITION_TRACKER_H_ +#endif // STOCHTREE_PARTITION_TRACKER_H_ diff --git a/include/stochtree/prior.h b/include/stochtree/prior.h index 5d8686f7..3e1117e7 100644 --- a/include/stochtree/prior.h +++ b/include/stochtree/prior.h @@ -25,18 +25,19 @@ class RandomEffectsRegressionGaussianPrior : public RandomEffectsGaussianPrior { num_groups_ = num_groups; } ~RandomEffectsRegressionGaussianPrior() {} - double GetPriorVarianceShape() {return a_;} - double GetPriorVarianceScale() {return b_;} - int32_t GetNumComponents() {return num_components_;} - int32_t GetNumGroups() {return num_groups_;} - void SetPriorVarianceShape(double a) {a_ = a;} - void SetPriorVarianceScale(double b) {b_ = b;} - void SetNumComponents(int32_t num_components) 
{num_components_ = num_components;} - void SetNumGroups(int32_t num_groups) {num_groups_ = num_groups;} + double GetPriorVarianceShape() { return a_; } + double GetPriorVarianceScale() { return b_; } + int32_t GetNumComponents() { return num_components_; } + int32_t GetNumGroups() { return num_groups_; } + void SetPriorVarianceShape(double a) { a_ = a; } + void SetPriorVarianceScale(double b) { b_ = b; } + void SetNumComponents(int32_t num_components) { num_components_ = num_components; } + void SetNumGroups(int32_t num_groups) { num_groups_ = num_groups; } + private: double a_; double b_; - int32_t num_components_; + int32_t num_components_; int32_t num_groups_; }; @@ -49,14 +50,15 @@ class TreePrior { max_depth_ = max_depth; } ~TreePrior() {} - double GetAlpha() {return alpha_;} - double GetBeta() {return beta_;} - int32_t GetMinSamplesLeaf() {return min_samples_in_leaf_;} - int32_t GetMaxDepth() {return max_depth_;} - void SetAlpha(double alpha) {alpha_ = alpha;} - void SetBeta(double beta) {beta_ = beta;} - void SetMinSamplesLeaf(int32_t min_samples_in_leaf) {min_samples_in_leaf_ = min_samples_in_leaf;} - void SetMaxDepth(int32_t max_depth) {max_depth_ = max_depth;} + double GetAlpha() { return alpha_; } + double GetBeta() { return beta_; } + int32_t GetMinSamplesLeaf() { return min_samples_in_leaf_; } + int32_t GetMaxDepth() { return max_depth_; } + void SetAlpha(double alpha) { alpha_ = alpha; } + void SetBeta(double beta) { beta_ = beta; } + void SetMinSamplesLeaf(int32_t min_samples_in_leaf) { min_samples_in_leaf_ = min_samples_in_leaf; } + void SetMaxDepth(int32_t max_depth) { max_depth_ = max_depth; } + private: double alpha_; double beta_; @@ -71,15 +73,16 @@ class IGVariancePrior { scale_ = scale; } ~IGVariancePrior() {} - double GetShape() {return shape_;} - double GetScale() {return scale_;} - void SetShape(double shape) {shape_ = shape;} - void SetScale(double scale) {scale_ = scale;} + double GetShape() { return shape_; } + double GetScale() { 
return scale_; } + void SetShape(double shape) { shape_ = shape; } + void SetScale(double scale) { scale_ = scale; } + private: double shape_; double scale_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_PRIOR_H_ \ No newline at end of file +#endif // STOCHTREE_PRIOR_H_ \ No newline at end of file diff --git a/include/stochtree/random.h b/include/stochtree/random.h index 3d39b647..6f5931e1 100644 --- a/include/stochtree/random.h +++ b/include/stochtree/random.h @@ -12,13 +12,13 @@ namespace StochTree { /*! -* \brief A wrapper for random generator -*/ + * \brief A wrapper for random generator + */ class Random { public: /*! - * \brief Constructor, with random seed - */ + * \brief Constructor, with random seed + */ Random() { std::random_device rd; auto genrator = std::mt19937(rd()); @@ -26,45 +26,45 @@ class Random { x = distribution(genrator); } /*! - * \brief Constructor, with specific seed - */ + * \brief Constructor, with specific seed + */ explicit Random(int seed) { x = seed; } /*! - * \brief Generate random integer, int16 range. `[0, 65536]` - * \param lower_bound lower bound - * \param upper_bound upper bound - * \return The random integer between [lower_bound, upper_bound) - */ + * \brief Generate random integer, int16 range. `[0, 65536]` + * \param lower_bound lower bound + * \param upper_bound upper bound + * \return The random integer between [lower_bound, upper_bound) + */ inline int NextShort(int lower_bound, int upper_bound) { return (RandInt16()) % (upper_bound - lower_bound) + lower_bound; } /*! 
- * \brief Generate random integer, int32 range - * \param lower_bound lower bound - * \param upper_bound upper bound - * \return The random integer between [lower_bound, upper_bound) - */ + * \brief Generate random integer, int32 range + * \param lower_bound lower bound + * \param upper_bound upper bound + * \return The random integer between [lower_bound, upper_bound) + */ inline int NextInt(int lower_bound, int upper_bound) { return (RandInt32()) % (upper_bound - lower_bound) + lower_bound; } /*! - * \brief Generate random float data - * \return The random float between `[0.0, 1.0)` - */ + * \brief Generate random float data + * \return The random float between `[0.0, 1.0)` + */ inline float NextFloat() { // get random float in `[0,1)` return static_cast(RandInt16()) / (32768.0f); } /*! - * \brief Sample K data from `{0,1,...,N-1}` - * \param N - * \param K - * \return K Ordered sampled data from `{0,1,...,N-1}` - */ + * \brief Sample K data from `{0,1,...,N-1}` + * \param N + * \param K + * \return K Ordered sampled data from `{0,1,...,N-1}` + */ inline std::vector Sample(int N, int K) { std::vector ret; ret.reserve(K); @@ -110,7 +110,6 @@ class Random { unsigned int x = 123456789; }; - } // namespace StochTree -#endif // STOCHTREE_RANDOM_H_ +#endif // STOCHTREE_RANDOM_H_ diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index b322a560..6d6a7127 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -36,29 +36,29 @@ class RandomEffectsTracker { public: RandomEffectsTracker(std::vector& group_indices); ~RandomEffectsTracker() {} - inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);} - inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);} - inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);} - inline data_size_t 
CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);} - inline int32_t NumCategories() {return num_categories_;} - inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);} - SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();} - CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();} + inline data_size_t GetCategoryId(int observation_num) { return sample_category_mapper_->GetCategoryId(observation_num); } + inline data_size_t CategoryBegin(int category_id) { return category_sample_tracker_->CategoryBegin(category_id); } + inline data_size_t CategoryEnd(int category_id) { return category_sample_tracker_->CategoryEnd(category_id); } + inline data_size_t CategorySize(int category_id) { return category_sample_tracker_->CategorySize(category_id); } + inline int32_t NumCategories() { return num_categories_; } + inline int32_t CategoryNumber(int32_t category_id) { return category_sample_tracker_->CategoryNumber(category_id); } + SampleCategoryMapper* GetSampleCategoryMapper() { return sample_category_mapper_.get(); } + CategorySampleTracker* GetCategorySampleTracker() { return category_sample_tracker_.get(); } std::vector::iterator UnsortedNodeBeginIterator(int category_id); std::vector::iterator UnsortedNodeEndIterator(int category_id); - std::map& GetLabelMap() {return category_sample_tracker_->GetLabelMap();} - std::vector& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();} - std::vector& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);} - std::vector& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} - double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);} - void SetPrediction(data_size_t observation_num, double pred) 
{rfx_predictions_.at(observation_num) = pred;} + std::map& GetLabelMap() { return category_sample_tracker_->GetLabelMap(); } + std::vector& GetUniqueGroupIds() { return category_sample_tracker_->GetUniqueGroupIds(); } + std::vector& NodeIndices(int category_id) { return category_sample_tracker_->NodeIndices(category_id); } + std::vector& NodeIndicesInternalIndex(int internal_category_id) { return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id); } + double GetPrediction(data_size_t observation_num) { return rfx_predictions_.at(observation_num); } + void SetPrediction(data_size_t observation_num, double pred) { rfx_predictions_.at(observation_num) = pred; } /*! \brief Resets RFX tracker based on a specific sample. Assumes tracker already exists in main memory. */ - void ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, + void ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, RandomEffectsDataset& rfx_dataset, ColumnVector& residual); - /*! \brief Resets RFX tracker to initial default. Assumes tracker already exists in main memory. + /*! \brief Resets RFX tracker to initial default. Assumes tracker already exists in main memory. * Assumes that the initial "clean slate" prediction of a random effects model is 0. 
*/ - void RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, + void RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, RandomEffectsDataset& rfx_dataset, ColumnVector& residual); private: @@ -113,11 +113,15 @@ class LabelMapper { this->Reset(); this->from_json(rfx_label_mapper_json); } - std::vector& Keys() {return keys_;} - std::map& Map() {return label_map_;} - void Reset() {label_map_.clear(); keys_.clear();} + std::vector& Keys() { return keys_; } + std::map& Map() { return label_map_; } + void Reset() { + label_map_.clear(); + keys_.clear(); + } nlohmann::json to_json(); void from_json(const nlohmann::json& rfx_label_mapper_json); + private: std::map label_map_; std::vector keys_; @@ -140,7 +144,7 @@ class MultivariateRegressionRandomEffectsModel { /*! \brief Reconstruction from serialized model parameter samples */ void ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num); - + /*! \brief Samplers */ void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen); void SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen); @@ -192,9 +196,9 @@ class MultivariateRegressionRandomEffectsModel { double GetVariancePriorScale() { return variance_prior_scale_; } - inline int32_t NumComponents() {return num_components_;} - inline int32_t NumGroups() {return num_groups_;} - + inline int32_t NumComponents() { return num_components_; } + inline int32_t NumGroups() { return num_groups_; } + std::vector Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) { std::vector output(dataset.NumObservations()); PredictInplace(dataset, tracker, output); @@ -266,7 +270,7 @@ class MultivariateRegressionRandomEffectsModel { /*! \brief Random effects structure details */ int num_components_; int num_groups_; - + /*! 
\brief Group mean parameters, decomposed into "working parameter" and individual parameters * under the "redundant" parameterization of Gelman et al (2008) */ @@ -275,7 +279,7 @@ class MultivariateRegressionRandomEffectsModel { /*! \brief Variance components for the group parameters */ Eigen::MatrixXd group_parameter_covariance_; - + /*! \brief Variance components for the working parameter */ Eigen::MatrixXd working_parameter_covariance_; @@ -320,12 +324,12 @@ class RandomEffectsContainer { void AddSample(MultivariateRegressionRandomEffectsModel& model); void DeleteSample(int sample_num); void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector& output); - inline int NumSamples() {return num_samples_;} - inline int NumComponents() {return num_components_;} - inline int NumGroups() {return num_groups_;} - inline void SetNumSamples(int num_samples) {num_samples_ = num_samples;} - inline void SetNumComponents(int num_components) {num_components_ = num_components;} - inline void SetNumGroups(int num_groups) {num_groups_ = num_groups;} + inline int NumSamples() { return num_samples_; } + inline int NumComponents() { return num_components_; } + inline int NumGroups() { return num_groups_; } + inline void SetNumSamples(int num_samples) { num_samples_ = num_samples; } + inline void SetNumComponents(int num_components) { num_components_ = num_components; } + inline void SetNumGroups(int num_groups) { num_groups_ = num_groups; } void Reset() { num_samples_ = 0; num_components_ = 0; @@ -335,13 +339,14 @@ class RandomEffectsContainer { xi_.clear(); sigma_xi_.clear(); } - std::vector& GetBeta() {return beta_;} - std::vector& GetAlpha() {return alpha_;} - std::vector& GetXi() {return xi_;} - std::vector& GetSigma() {return sigma_xi_;} + std::vector& GetBeta() { return beta_; } + std::vector& GetAlpha() { return alpha_; } + std::vector& GetXi() { return xi_; } + std::vector& GetSigma() { return sigma_xi_; } nlohmann::json to_json(); void from_json(const 
nlohmann::json& rfx_container_json); void append_from_json(const nlohmann::json& rfx_container_json); + private: int num_samples_; int num_components_; @@ -355,6 +360,6 @@ class RandomEffectsContainer { void AddSigma(MultivariateRegressionRandomEffectsModel& model); }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_RANDOM_EFFECTS_H_ +#endif // STOCHTREE_RANDOM_EFFECTS_H_ diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index 3810e3cb..61bf6005 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -52,13 +52,13 @@ enum FeatureSplitType { /*! \brief Forward declaration of TreeSplit class */ class TreeSplit; -/*! +/*! * \defgroup tree_group Tree API - * + * * \brief Classes / functions for creating and modifying decision trees. - * + * * \section tree_design Design - * + * * \{ */ @@ -68,7 +68,7 @@ class Tree { static constexpr std::int32_t kInvalidNodeId{-1}; static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits::max(); static constexpr std::int32_t kRoot{0}; - + Tree() = default; // ~Tree() = default; Tree(Tree const&) = delete; @@ -76,9 +76,9 @@ class Tree { Tree(Tree&&) noexcept = default; Tree& operator=(Tree&&) noexcept = default; /*! - * \brief Copy the structure and parameters of another tree. If the `Tree` object calling this method already + * \brief Copy the structure and parameters of another tree. If the `Tree` object calling this method already * has a non-root tree structure / parameters, this will be erased and replaced with a copy of `tree`. - * + * * \param tree Tree to be cloned */ void CloneFromTree(Tree* tree); @@ -102,19 +102,19 @@ class Tree { void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector left_value_vector, std::vector right_value_vector); /*! 
\brief Expand a node based on a categorical split rule */ void ExpandNode(std::int32_t nid, int split_index, std::vector const& categorical_indices, std::vector left_value_vector, std::vector right_value_vector); - /*! \brief Expand a node based on a generic split rule */ + /*! \brief Expand a node based on a generic split rule */ void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, double left_value, double right_value); /*! \brief Expand a node based on a generic split rule */ void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, std::vector left_value_vector, std::vector right_value_vector); /*! \brief Whether or not a tree is a "stump" consisting of a single root node */ - inline bool IsRoot() {return leaves_.size() == 1;} - + inline bool IsRoot() { return leaves_.size() == 1; } + /*! \brief Convert tree to JSON and return JSON in-memory */ json to_json(); - /*! - * \brief Load from JSON - * + /*! + * \brief Load from JSON + * * \param tree_json In-memory json object (of type `nlohmann::json`) */ void from_json(const json& tree_json); @@ -135,7 +135,7 @@ class Tree { // TODO refactor and add this to the multivariate case as well if (!IsRoot(nid)) { int parent_id = Parent(nid); - if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){ + if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))) { leaf_parents_.push_back(parent_id); } } @@ -174,12 +174,12 @@ class Tree { // TODO refactor and add this to the multivariate case as well if (!IsRoot(nid)) { int parent_id = Parent(nid); - if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){ + if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))) { leaf_parents_.push_back(parent_id); } } } - + /*! * \brief Collapse an internal node to a leaf node, deleting its children from the tree * \param nid Node id of the new leaf node @@ -200,7 +200,7 @@ class Tree { /*! * \brief Add a constant value to every leaf of a tree. 
If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. - * + * * \param constant_value Value that will be added to every leaf of a tree */ void AddValueToLeaves(double constant_value) { @@ -217,7 +217,7 @@ class Tree { /*! * \brief Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional, `constant_value` will be multiplied through every dimension of the leaves. - * + * * \param constant_multiple Value that will be multiplied by every leaf of a tree */ void MultiplyLeavesByValue(double constant_multiple) { @@ -234,14 +234,15 @@ class Tree { /*! * \brief Iterate through all nodes in this tree. - * + * * \tparam Func Function object type, must map `std::int32_t` to `bool`. * \param func Function that accepts a node index and returns `False` when iteration through a given branch of the tree should stop and `True` otherwise. */ - template void WalkTree(Func func) const { + template + void WalkTree(Func func) const { std::stack nodes; nodes.push(kRoot); - auto &self = *this; + auto& self = *this; while (!nodes.empty()) { auto nidx = nodes.top(); nodes.pop(); @@ -285,7 +286,7 @@ class Tree { bool IsLogScale() const { return is_log_scale_; } - + /*! * \brief Index of the node's parent * \param nid ID of node being queried @@ -293,7 +294,7 @@ class Tree { std::int32_t Parent(std::int32_t nid) const { return parent_[nid]; } - + /*! * \brief Index of the node's left child * \param nid ID of node being queried @@ -301,7 +302,7 @@ class Tree { std::int32_t LeftChild(std::int32_t nid) const { return cleft_[nid]; } - + /*! * \brief Index of the node's right child * \param nid ID of node being queried @@ -309,7 +310,7 @@ class Tree { std::int32_t RightChild(std::int32_t nid) const { return cright_[nid]; } - + /*! 
* \brief Index of the node's "default" child (potentially used in the case of a missing feature at prediction time) * \param nid ID of node being queried @@ -317,7 +318,7 @@ class Tree { std::int32_t DefaultChild(std::int32_t nid) const { return cleft_[nid]; } - + /*! * \brief Feature index defining the node's split rule * \param nid ID of node being queried @@ -325,7 +326,7 @@ class Tree { std::int32_t SplitIndex(std::int32_t nid) const { return split_index_[nid]; } - + /*! * \brief Whether the node is a leaf node * \param nid ID of node being queried @@ -333,7 +334,7 @@ class Tree { bool IsLeaf(std::int32_t nid) const { return cleft_[nid] == kInvalidNodeId; } - + /*! * \brief Whether the node is root * \param nid ID of node being queried @@ -357,7 +358,7 @@ class Tree { double LeafValue(std::int32_t nid) const { return leaf_value_[nid]; } - + /*! * \brief Get parameter value of a node (typically though not necessarily a leaf node) at a given output dimension * \param nid ID of node being queried @@ -386,7 +387,7 @@ class Tree { std::stack node_depths; nodes.push(kRoot); node_depths.push(0); - auto &self = *this; + auto& self = *this; while (!nodes.empty()) { auto nidx = nodes.top(); nodes.pop(); @@ -399,11 +400,11 @@ class Tree { auto right = self.RightChild(nidx); if (left != Tree::kInvalidNodeId) { nodes.push(left); - node_depths.push(node_depth+1); + node_depths.push(node_depth + 1); } if (right != Tree::kInvalidNodeId) { nodes.push(right); - node_depths.push(node_depth+1); + node_depths.push(node_depth + 1); } } } @@ -457,7 +458,7 @@ class Tree { } return result; } - + /*! * \brief Tests whether the leaf node has a non-empty leaf vector * \param nid ID of node being queried @@ -476,7 +477,7 @@ class Tree { /*! * \brief Get list of all categories belonging to the left child node. - * Categories are integers ranging from 0 to (n-1), where n is the number of categories in that particular feature. 
+ * Categories are integers ranging from 0 to (n-1), where n is the number of categories in that particular feature. * This list is assumed to be in ascending order. * * \param nid ID of node being queried @@ -538,7 +539,7 @@ class Tree { bool is_right_leaf = false; // Check if node nidx is a leaf, if so, return false bool is_leaf = this->IsLeaf(nid); - if (is_leaf){ + if (is_leaf) { return false; } else { // If nidx is not a leaf, it must have left and right nodes @@ -571,7 +572,7 @@ class Tree { [[nodiscard]] std::vector const& GetLeafParents() const { return leaf_parents_; } - + /*! * \brief Get indices of all valid (non-deleted) nodes. */ @@ -579,11 +580,11 @@ class Tree { std::vector output; auto const& self = *this; this->WalkTree([&output, &self](std::int32_t nidx) { - if (!self.IsDeleted(nidx)) { - output.push_back(nidx); - } - return true; - }); + if (!self.IsDeleted(nidx)) { + output.push_back(nidx); + } + return true; + }); return output; } @@ -604,12 +605,12 @@ class Tree { * \brief Get the total number of nodes including deleted ones in this tree. */ [[nodiscard]] std::int32_t NumNodes() const noexcept { return num_nodes; } - + /** * \brief Get the total number of deleted nodes in this tree. */ [[nodiscard]] std::int32_t NumDeletedNodes() const noexcept { return num_deleted_nodes; } - + /** * \brief Get the total number of valid nodes in this tree. */ @@ -675,7 +676,7 @@ class Tree { */ void SetNumericSplit( std::int32_t nid, std::int32_t split_index, double threshold); - + /*! * \brief Create a categorical split * \param nid ID of node being updated @@ -685,8 +686,8 @@ class Tree { * which node the category list should represent. */ void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, - std::vector const& category_list); - + std::vector const& category_list); + /*! * \brief Set the leaf value of the node * \param nid ID of node being updated @@ -703,18 +704,18 @@ class Tree { /*! 
* \brief Obtain a 0-based leaf index for each observation in a ForestDataset. - * Internally, trees are stored as vectors of node information, + * Internally, trees are stored as vectors of node information, * and the `leaves_` vector gives us node IDs for every leaf in the tree. - * Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. + * Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. * - * Note: this is a tree-level helper function for an ensemble-level function. - * It assumes the creation of: + * Note: this is a tree-level helper function for an ensemble-level function. + * It assumes the creation of: * -# a vector of column indices of size `dataset.NumObservations()` x `ensemble.NumTrees()`, stored in "tree-major" order - * -# a running counter of the number of tree-observations already indexed in the ensemble + * -# a running counter of the number of tree-observations already indexed in the ensemble * (used as offsets for the leaf number computed and returned here) - * Users running this function for a single tree may simply pre-allocate an output vector as + * Users running this function for a single tree may simply pre-allocate an output vector as * `std::vector output(dataset->NumObservations())` and set the offset to 0. * \param dataset Dataset with which to predict leaf indices from the tree * \param output Pre-allocated output vector storing a matrix of column indices, with "rows" corresponding to observations in `dataset` and "columns" corresponding to trees in an ensemble @@ -725,18 +726,18 @@ class Tree { /*! * \brief Obtain a 0-based leaf index for each observation in a ForestDataset. 
- * Internally, trees are stored as vectors of node information, + * Internally, trees are stored as vectors of node information, * and the `leaves_` vector gives us node IDs for every leaf in the tree. - * Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. + * Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. * - * Note: this is a tree-level helper function for an ensemble-level function. - * It assumes the creation of: + * Note: this is a tree-level helper function for an ensemble-level function. + * It assumes the creation of: * -# a vector of column indices of size `dataset.NumObservations()` x `ensemble.NumTrees()`, stored in "tree-major" order - * -# a running counter of the number of tree-observations already indexed in the ensemble + * -# a running counter of the number of tree-observations already indexed in the ensemble * (used as offsets for the leaf number computed and returned here) - * Users running this function for a single tree may simply pre-allocate an output vector as + * Users running this function for a single tree may simply pre-allocate an output vector as * `std::vector output(dataset->NumObservations())` and set the offset to 0. * \param covariates Eigen matrix with which to predict leaf indices * \param output Pre-allocated output vector storing a matrix of column indices, with "rows" corresponding to observations in `covariates` and "columns" corresponding to trees in an ensemble @@ -747,18 +748,18 @@ class Tree { /*! * \brief Obtain a 0-based leaf index for each observation in a ForestDataset. 
- * Internally, trees are stored as vectors of node information, + * Internally, trees are stored as vectors of node information, * and the `leaves_` vector gives us node IDs for every leaf in the tree. - * Here, we would like to know, for every observation in a dataset, - * which leaf number it is mapped to. Since the leaf numbers themselves - * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. + * Here, we would like to know, for every observation in a dataset, + * which leaf number it is mapped to. Since the leaf numbers themselves + * do not carry any information, we renumber them from `0` to `leaves_.size()-1`. * - * Note: this is a tree-level helper function for an ensemble-level function. - * It assumes the creation of: + * Note: this is a tree-level helper function for an ensemble-level function. + * It assumes the creation of: * -# a vector of column indices of size `dataset.NumObservations()` x `ensemble.NumTrees()`, stored in "tree-major" order - * -# a running counter of the number of tree-observations already indexed in the ensemble + * -# a running counter of the number of tree-observations already indexed in the ensemble * (used as offsets for the leaf number computed and returned here) - * Users running this function for a single tree may simply pre-allocate an output vector as + * Users running this function for a single tree may simply pre-allocate an output vector as * `std::vector output(dataset->NumObservations())` and set the offset to 0. 
* \param covariates Eigen matrix with which to predict leaf indices * \param output Pre-allocated output vector storing a matrix of column indices, with "rows" corresponding to observations in `covariates` and "columns" corresponding to trees in an ensemble @@ -767,8 +768,8 @@ class Tree { */ void PredictLeafIndexInplace(Eigen::Map>& covariates, std::vector& output, int32_t offset, int32_t max_leaf); - void PredictLeafIndexInplace(Eigen::Map>& covariates, - Eigen::Map>& output, + void PredictLeafIndexInplace(Eigen::Map>& covariates, + Eigen::Map>& output, int column_ind, int32_t offset, int32_t max_leaf); // Node info @@ -784,7 +785,7 @@ class Tree { std::vector leaves_; std::vector leaf_parents_; std::vector deleted_nodes_; - + // Leaf vector std::vector leaf_vector_; std::vector leaf_vector_begin_; @@ -803,27 +804,26 @@ class Tree { /*! \brief Comparison operator for trees */ inline bool operator==(const Tree& lhs, const Tree& rhs) { return ( - (lhs.has_categorical_split_ == rhs.has_categorical_split_) && - (lhs.output_dimension_ == rhs.output_dimension_) && - (lhs.is_log_scale_ == rhs.is_log_scale_) && - (lhs.node_type_ == rhs.node_type_) && - (lhs.parent_ == rhs.parent_) && - (lhs.cleft_ == rhs.cleft_) && - (lhs.cright_ == rhs.cright_) && - (lhs.split_index_ == rhs.split_index_) && - (lhs.leaf_value_ == rhs.leaf_value_) && - (lhs.threshold_ == rhs.threshold_) && - (lhs.internal_nodes_ == rhs.internal_nodes_) && - (lhs.leaves_ == rhs.leaves_) && - (lhs.leaf_parents_ == rhs.leaf_parents_) && - (lhs.deleted_nodes_ == rhs.deleted_nodes_) && - (lhs.leaf_vector_ == rhs.leaf_vector_) && - (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) && - (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) && - (lhs.category_list_ == rhs.category_list_) && - (lhs.category_list_begin_ == rhs.category_list_begin_) && - (lhs.category_list_end_ == rhs.category_list_end_) - ); + (lhs.has_categorical_split_ == rhs.has_categorical_split_) && + (lhs.output_dimension_ == rhs.output_dimension_) 
&& + (lhs.is_log_scale_ == rhs.is_log_scale_) && + (lhs.node_type_ == rhs.node_type_) && + (lhs.parent_ == rhs.parent_) && + (lhs.cleft_ == rhs.cleft_) && + (lhs.cright_ == rhs.cright_) && + (lhs.split_index_ == rhs.split_index_) && + (lhs.leaf_value_ == rhs.leaf_value_) && + (lhs.threshold_ == rhs.threshold_) && + (lhs.internal_nodes_ == rhs.internal_nodes_) && + (lhs.leaves_ == rhs.leaves_) && + (lhs.leaf_parents_ == rhs.leaf_parents_) && + (lhs.deleted_nodes_ == rhs.deleted_nodes_) && + (lhs.leaf_vector_ == rhs.leaf_vector_) && + (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) && + (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) && + (lhs.category_list_ == rhs.category_list_) && + (lhs.category_list_begin_ == rhs.category_list_begin_) && + (lhs.category_list_end_ == rhs.category_list_end_)); } /*! \brief Determine whether an observation produces a "true" value in a numeric split node @@ -847,15 +847,13 @@ inline bool SplitTrueCategorical(double fvalue, std::vector const // A valid (integer) category must satisfy two criteria: // 1) it must be exactly representable as double // 2) it must fit into uint32_t - auto max_representable_int - = std::min(static_cast(std::numeric_limits::max()), - static_cast(std::uint64_t(1) << std::numeric_limits::digits)); + auto max_representable_int = std::min(static_cast(std::numeric_limits::max()), + static_cast(std::uint64_t(1) << std::numeric_limits::digits)); if (fvalue < 0 || std::fabs(fvalue) > max_representable_int) { category_matched = false; } else { auto const category_value = static_cast(fvalue); - category_matched = (std::find(category_list.begin(), category_list.end(), category_value) - != category_list.end()); + category_matched = (std::find(category_list.begin(), category_list.end(), category_value) != category_list.end()); } return category_matched; } @@ -880,9 +878,9 @@ inline int NextNodeCategorical(double fvalue, std::vector const& return SplitTrueCategorical(fvalue, category_list) ? 
left_child : right_child; } -/*! +/*! * Determine the node at which a tree places a given observation - * + * * \param tree Tree object used for prediction * \param data Dataset used for prediction * \param row Row indexing the prediction observation @@ -897,7 +895,7 @@ inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) { } else { if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) { node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id), - tree.LeftChild(node_id), tree.RightChild(node_id)); + tree.LeftChild(node_id), tree.RightChild(node_id)); } else { node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id)); } @@ -906,9 +904,9 @@ inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) { return node_id; } -/*! +/*! * Determine the node at which a tree places a given observation - * + * * \param tree Tree object used for prediction * \param data Dataset used for prediction * \param row Row indexing the prediction observation @@ -923,7 +921,7 @@ inline int EvaluateTree(Tree const& tree, Eigen::Map& split_categories) { @@ -979,22 +977,25 @@ class TreeSplit { split_set_ = true; } ~TreeSplit() {} - bool SplitSet() {return split_set_;} + bool SplitSet() { return split_set_; } /*! \brief Whether or not a `TreeSplit` rule is numeric */ - bool NumericSplit() {return numeric_;} + bool NumericSplit() { return numeric_; } /*! * \brief Whether a given covariate value is `True` or `False` on the rule defined by a `TreeSplit` object - * + * * \param fvalue Value of the covariate */ bool SplitTrue(double fvalue) { - if (numeric_) return SplitTrueNumeric(fvalue, split_value_); - else return SplitTrueCategorical(fvalue, split_categories_); + if (numeric_) + return SplitTrueNumeric(fvalue, split_value_); + else + return SplitTrueCategorical(fvalue, split_categories_); } /*! 
\brief Numeric cutoff value defining a `TreeSplit` object */ - double SplitValue() {return split_value_;} + double SplitValue() { return split_value_; } /*! \brief Categories defining a `TreeSplit` object */ - std::vector SplitCategories() {return split_categories_;} + std::vector SplitCategories() { return split_categories_; } + private: bool split_set_{false}; bool numeric_; @@ -1002,8 +1003,8 @@ class TreeSplit { std::vector split_categories_; }; -/*! \} */ // end of tree_group +/*! \} */ // end of tree_group -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_TREE_H_ +#endif // STOCHTREE_TREE_H_ diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 8a12f81f..0092ed2d 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -23,12 +23,12 @@ namespace StochTree { /*! * \defgroup sampling_group Forest Sampler API * - * \brief Functions for sampling from a forest. The core interface of these functions, - * as used by the R, Python, and standalone C++ program, is defined by - * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a - * given forest, and \ref GFRSampleOneIter, which runs one iteration of the - * grow-from-root (GFR) algorithm for a given forest. All other functions are - * essentially helpers used in a sampling function, which are documented here + * \brief Functions for sampling from a forest. The core interface of these functions, + * as used by the R, Python, and standalone C++ program, is defined by + * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a + * given forest, and \ref GFRSampleOneIter, which runs one iteration of the + * grow-from-root (GFR) algorithm for a given forest. All other functions are + * essentially helpers used in a sampling function, which are documented here * to make extending the C++ codebase more straightforward. * * \{ @@ -36,7 +36,7 @@ namespace StochTree { /*! 
* \brief Computer the range of available split values for a continuous variable, given the current structure of a tree. - * + * * \param tracker Tracking data structures that speed up sampler operations. * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights. * \param tree_num Index of the tree for which a split is proposed. @@ -49,10 +49,10 @@ static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, var_min = std::numeric_limits::max(); var_max = std::numeric_limits::min(); double feature_value; - + std::vector::iterator node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split); std::vector::iterator node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split); - + for (auto i = node_begin_iter; i != node_end_iter; i++) { auto idx = *i; feature_value = dataset.CovariateValue(idx, feature_split); @@ -66,14 +66,14 @@ static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, /*! * \brief Determines whether a proposed split creates two leaf nodes with constant values for every feature (thus ensuring that the tree cannot split further). - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights. * \param tracker Tracking data structures that speed up sampler operations. * \param split Proposed split of tree `tree_num` at node `leaf_split`. * \param tree_num Index of the tree for which a split is proposed. * \param leaf_split Index of the leaf in `tree_num` for which a split is proposed. * \param feature_split Index of the feature to which `split` will be applied - * \return `true` if `split` creates two nodes with constant values for every feature in `dataset`, `false` otherwise. + * \return `true` if `split` creates two nodes with constant values for every feature in `dataset`, `false` otherwise. 
*/ static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) { int p = dataset.GetCovariates().cols(); @@ -84,7 +84,7 @@ static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTrac double var_min_left; double var_max_right; double var_min_right; - + for (int j = 0; j < p; j++) { auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split); auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split); @@ -124,7 +124,7 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke double feature_value; double var_max; double var_min; - + for (int j = 0; j < p; j++) { auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id); auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id); @@ -147,7 +147,7 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke return false; } -static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, +static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) { // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete if (tree->OutputDimension() > 1) { @@ -164,7 +164,7 @@ static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& datase tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads); } -static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, +static inline void 
RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) { // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete if (tree->OutputDimension() > 1) { @@ -203,7 +203,7 @@ static inline double ComputeVarianceOutcome(ColumnVector& residual) { return sum_y_sq / static_cast(n) - (sum_y * sum_y) / (static_cast(n) * static_cast(n)); } -static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, +static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function op) { data_size_t n = dataset.GetCovariates().rows(); double tree_pred = 0.; @@ -222,7 +222,7 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas tracker.SetTreeSamplePrediction(i, j, tree_pred); pred_value += tree_pred; } - + // Run op (either plus or minus) on the residual and the new prediction new_resid = op(residual.GetElement(i), pred_value); residual.SetElement(i, new_resid); @@ -230,7 +230,7 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas tracker.SyncPredictions(); } -static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, +static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function op) { data_size_t n = dataset.GetCovariates().rows(); double tree_pred = 0.; @@ -248,14 +248,14 @@ static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestD } pred_value += tree_pred; } - + // Run op (either plus or minus) on the residual and the new prediction new_resid = 
op(residual.GetElement(i), pred_value); residual.SetElement(i, new_resid); } } -static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, +static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function op) { data_size_t n = dataset.GetCovariates().rows(); double tree_pred = 0.; @@ -274,7 +274,7 @@ static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestData tracker.SetTreeSamplePrediction(i, j, tree_pred); pred_value += tree_pred; } - + // Run op (either plus or minus) on the residual and the new prediction new_resid = op(residual.GetElement(i), pred_value); residual.SetElement(i, new_resid); @@ -296,8 +296,8 @@ static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector } } -static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, - bool requires_basis, std::function op, bool tree_new) { +static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, + bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); double pred_value; int32_t leaf_pred; @@ -305,7 +305,7 @@ static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& da double pred_delta; for (data_size_t i = 0; i < n; i++) { if (tree_new) { - // If the tree has been newly sampled or adjusted, we must rerun the prediction + // If the tree has been newly sampled or adjusted, we must rerun the prediction // method and update the SamplePredMapper stored in tracker leaf_pred = tracker.GetNodeId(i, tree_num); if (requires_basis) { @@ -317,7 +317,7 @@ static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& da tracker.SetTreeSamplePrediction(i, 
tree_num, pred_value); tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta); } else { - // If the tree has not yet been modified via a sampling step, + // If the tree has not yet been modified via a sampling step, // we can query its prediction directly from the SamplePredMapper stored in tracker pred_value = tracker.GetTreeSamplePrediction(i, tree_num); } @@ -345,13 +345,13 @@ static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& // Compute new prediction based on updated basis leaf_pred = tracker.GetNodeId(i, tree_num); new_tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i); - + // Cache the new prediction in the tracker tracker.SetTreeSamplePrediction(i, tree_num, new_tree_pred); // Subtract out the updated tree prediction new_resid -= new_tree_pred; - + // Propagate the change back to the residual residual.SetElement(i, new_resid); } @@ -359,7 +359,7 @@ static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& tracker.SyncPredictions(); } -static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, +static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); double pred_value; @@ -370,7 +370,7 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat double prev_pred; for (data_size_t i = 0; i < n; i++) { if (tree_new) { - // If the tree has been newly sampled or adjusted, we must rerun the prediction + // If the tree has been newly sampled or adjusted, we must rerun the prediction // method and update the SamplePredMapper stored in tracker leaf_pred = tracker.GetNodeId(i, tree_num); if (requires_basis) { @@ -386,7 +386,7 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat 
new_weight = std::log(dataset.VarWeightValue(i)) + pred_value; dataset.SetVarWeightValue(i, new_weight, true); } else { - // If the tree has not yet been modified via a sampling step, + // If the tree has not yet been modified via a sampling step, // we can query its prediction directly from the SamplePredMapper stored in tracker pred_value = tracker.GetTreeSamplePrediction(i, tree_num); new_weight = std::log(dataset.VarWeightValue(i)) - pred_value; @@ -395,8 +395,8 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat } } -static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, - bool requires_basis, bool tree_new) { +static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, + bool requires_basis, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); double pred_value; @@ -404,7 +404,7 @@ static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& double pred_delta; for (data_size_t i = 0; i < n; i++) { if (tree_new) { - // If the tree has been newly sampled or adjusted, we must rerun the prediction + // If the tree has been newly sampled or adjusted, we must rerun the prediction // method and update the SamplePredMapper stored in tracker leaf_pred = tracker.GetNodeId(i, tree_num); if (requires_basis) { @@ -418,7 +418,7 @@ static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num) dataset.SetAuxiliaryDataValue(1, i, tracker.GetSamplePrediction(i) - pred_value); } else { - // If the tree has not yet been modified via a sampling step, + // If the tree has not yet been modified via a sampling step, // we can query its prediction directly from the SamplePredMapper stored in tracker pred_value = tracker.GetTreeSamplePrediction(i, 
tree_num); // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num): needed? since tree not changed? @@ -431,10 +431,9 @@ static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& template static inline std::tuple EvaluateProposedSplit( - ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, - TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance, - int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args -) { + ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, + TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance, + int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -442,10 +441,9 @@ static inline std::tuple EvaluatePropo // Accumulate sufficient statistics AccumulateSuffStatProposed( - node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads, - leaf_suff_stat_args... - ); + node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, + residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads, + leaf_suff_stat_args...); data_size_t left_n = left_suff_stat.n; data_size_t right_n = right_suff_stat.n; @@ -458,17 +456,16 @@ static inline std::tuple EvaluatePropo template static inline std::tuple EvaluateExistingSplit( - ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, - double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id, - LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args -) { + ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, + double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id, + LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); // Accumulate sufficient statistics - AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, + AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id); data_size_t left_n = left_suff_stat.n; data_size_t right_n = right_suff_stat.n; @@ -481,8 +478,8 @@ static inline std::tuple EvaluateExist } template -static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { +static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); } else if (backfitting) { @@ -494,8 +491,8 @@ static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafMod } template -static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { +static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, 
ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); } else if (backfitting) { @@ -507,10 +504,10 @@ static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafMode } template -static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, - std::unordered_map>& node_index_map, std::deque& split_queue, - int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, +static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, + std::unordered_map>& node_index_map, std::deque& split_queue, + int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, std::vector& feature_types, std::vector feature_subset, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Leaf depth int leaf_depth = tree->GetDepth(node_id); @@ -519,15 +516,14 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel int32_t max_depth = tree_prior.GetMaxDepth(); if ((max_depth == -1) || (leaf_depth < max_depth)) { - // Vector of vectors to store results for each feature int p = dataset.NumCovariates(); - std::vector> feature_log_cutpoint_evaluations(p+1); - std::vector> feature_cutpoint_values(p+1); - std::vector feature_cutpoint_counts(p+1, 0); + std::vector> feature_log_cutpoint_evaluations(p + 1); + std::vector> feature_cutpoint_values(p + 1); + std::vector feature_cutpoint_counts(p + 1, 0); StochTree::data_size_t valid_cutpoint_count; - - // Evaluate all possible cutpoints according to the leaf node model, + + // Evaluate all possible cutpoints according to the leaf node model, // recording their log-likelihood and other split information in a series of vectors. // Initialize node sufficient statistics @@ -545,7 +541,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel Eigen::VectorXd var_weights; bool has_weights = dataset.HasVarWeights(); if (has_weights) var_weights = dataset.GetVarWeights(); - + // Minimum size of newly created leaf nodes (used to rule out invalid splits) int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); @@ -557,13 +553,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel // Initialize cutpoint grid container CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size); - + // Evaluate all possible splits for each feature in parallel StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) { if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { // Enumerate cutpoint strides cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - + // Left and 
right node sufficient statistics LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -584,18 +580,18 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel // Compute the corresponding right node sufficient statistics right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - // Store the bin index as the "cutpoint value" - we can use this to query the actual split + // Store the bin index as the "cutpoint value" - we can use this to query the actual split // value or the set of split categories later on once a split is chose double cutoff_value = cutpoint_idx; // Only include cutpoint for consideration if it defines a valid split in the training data - bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && + bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); if (valid_split) { feature_cutpoint_counts[j]++; // Add to split rule vector feature_cutpoint_values[j].push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector + // Add the log marginal likelihood of the split to the split eval vector double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); feature_log_cutpoint_evaluations[j].push_back(split_log_ml); } @@ -608,29 +604,29 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml); - + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints double bart_prior_no_split_adj; double alpha = tree_prior.GetAlpha(); double beta = tree_prior.GetBeta(); int node_depth = 
tree->GetDepth(node_id); if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + bart_prior_no_split_adj = std::log(((std::pow(1 + node_depth, beta)) / alpha) - 1.0); } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); + bart_prior_no_split_adj = std::log(((std::pow(1 + node_depth, beta)) / alpha) - 1.0) + std::log(valid_cutpoint_count); } feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj; - // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood double largest_ml = -std::numeric_limits::infinity(); for (int j = 0; j < p + 1; j++) { if (feature_log_cutpoint_evaluations[j].size() > 0) { - double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());; + double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end()); + ; largest_ml = std::max(largest_ml, feature_max_ml); } } - std::vector> feature_cutpoint_evaluations(p+1); + std::vector> feature_cutpoint_evaluations(p + 1); for (int j = 0; j < p + 1; j++) { if (feature_log_cutpoint_evaluations[j].size() > 0) { feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size()); @@ -641,7 +637,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Compute sum of marginal likelihoods for each feature - std::vector feature_total_cutpoint_evaluations(p+1, 0.0); + std::vector feature_total_cutpoint_evaluations(p + 1, 0.0); for (int j = 0; j < p + 1; j++) { if (feature_log_cutpoint_evaluations[j].size() > 0) { feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0); @@ -655,8 +651,8 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& 
tracker, LeafModel // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen] int cutpoint_chosen = sample_discrete_stateless(gen, feature_cutpoint_evaluations[feature_chosen], feature_total_cutpoint_evaluations[feature_chosen]); - - if (feature_chosen == p){ + + if (feature_chosen == p) { // "No split" sampled, don't split or add any nodes to split queue return; } else { @@ -665,14 +661,14 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel FeatureType feature_type = feature_types[feature_split]; double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen]; // Perform all of the relevant "split" operations in the model, tree and training dataset - + // Compute node sample size data_size_t node_n = node_end - node_begin; - + // Actual numeric cutpoint used for ordered categorical and numeric features double split_value_numeric; TreeSplit tree_split; - + // We will use these later in the model expansion data_size_t left_n = 0; data_size_t right_n = 0; @@ -696,7 +692,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } else { Log::Fatal("Invalid split type"); } - + // Add split to tree and trackers AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); @@ -715,15 +711,15 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel // Add the left and right nodes to the split tracker split_queue.push_front(right_node); - split_queue.push_front(left_node); + split_queue.push_front(left_node); } } } template static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size, + ColumnVector& residual, 
TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { int root_id = Tree::kRoot; int curr_node_id; @@ -748,9 +744,8 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore std::iota(feature_indices.begin(), feature_indices.end(), 0); std::vector features_selected(num_features_subsample); sample_without_replacement( - features_selected.data(), variable_weights.data(), feature_indices.data(), - p, num_features_subsample, gen - ); + features_selected.data(), variable_weights.data(), feature_indices.data(), + p, num_features_subsample, gen); for (int i = 0; i < p; i++) { feature_subset.at(i) = false; } @@ -778,21 +773,20 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore curr_node_end = begin_end.second; // Draw a split rule at random SampleSplitRule( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, - node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, - feature_subset, num_threads, leaf_suff_stat_args... - ); + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, + node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, + feature_subset, num_threads, leaf_suff_stat_args...); } } -/*! +/*! * Runs one iteration of the "grow-from-root" (GFR) sampler for a tree ensemble model, which consists of two steps for every tree in a forest: * 1. Grow a tree by recursively sampling cutpoint via the GFR algorithm * 2. Sampling leaf node parameters, conditional on an updated tree, via a Gibbs sampler - * + * * \tparam LeafModel Leaf model type (i.e. 
`GaussianConstantLeafModel`, `GaussianUnivariateRegressionLeafModel`, etc...) * \tparam LeafSuffStat Leaf sufficient statistic type (i.e. `GaussianConstantSuffStat`, `GaussianUnivariateRegressionSuffStat`, etc...) - * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, + * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, * this is `int`, while each of the other three sufficient statistic classes do not take a constructor argument. * \param active_forest Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class * of stored ensemble samples because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler). @@ -817,38 +811,37 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. 
*/ template -static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, +static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the GFR algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { // Adjust any model state needed to run a tree sampler - // For models that involve Bayesian backfitting, this amounts to adding tree i's + // For models that involve Bayesian backfitting, this amounts to adding tree i's // predictions back to the residual (thus, training a model on the "partial residual") // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object Tree* tree = active_forest.GetTree(i); AdjustStateBeforeTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i); - + // Reset the tree and sample trackers active_forest.ResetInitTree(i); tracker.ResetRoot(dataset.GetCovariates(), feature_types, i); tree = active_forest.GetTree(i); - + // Sample tree i GFRSampleTreeOneIter( - tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, - variable_weights, i, global_variance, feature_types, cutpoint_grid_size, - num_features_subsample, num_threads, 
leaf_suff_stat_args... - ); - + tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, + variable_weights, i, global_variance, feature_types, cutpoint_grid_size, + num_features_subsample, num_threads, leaf_suff_stat_args...); + // Sample leaf parameters for tree i tree = active_forest.GetTree(i); leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); - + // Adjust any model state needed to run a tree sampler - // For models that involve Bayesian backfitting, this amounts to subtracting tree i's + // For models that involve Bayesian backfitting, this amounts to subtracting tree i's // predictions back out of the residual (thus, using an updated "partial residual" in the following interation). // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object AdjustStateAfterTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i); @@ -860,8 +853,8 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& } template -static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, +static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Extract dataset information data_size_t n = dataset.GetCovariates().rows(); @@ -870,7 +863,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM int num_leaves = tree->NumLeaves(); std::vector leaves = tree->GetLeaves(); std::vector leaf_weights(num_leaves); - std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves); + std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0 / num_leaves); walker_vose leaf_dist(leaf_weights.begin(), leaf_weights.end()); int leaf_chosen = leaves[leaf_dist(gen)]; int leaf_depth = tree->GetDepth(leaf_chosen); @@ -883,7 +876,6 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM if ((leaf_depth >= max_depth) && (max_depth != -1)) { accept = false; } else { - // Select a split variable at random int p = dataset.GetCovariates().cols(); CHECK_EQ(variable_weights.size(), p); @@ -897,7 +889,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM if (var_max <= var_min) { return; } - + // Split based on var_min to var_max in a given node double split_point_chosen = standard_uniform_draw_53bit(gen) * (var_max - var_min) + var_min; @@ -906,8 +898,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM // Compute the marginal likelihood of split and no split, given the leaf prior std::tuple split_eval = EvaluateProposedSplit( - dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args... 
- ); + dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args...); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); int32_t left_n = std::get<2>(split_eval); @@ -917,18 +908,17 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf(); bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf(); if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) { - // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + double pg = tree_prior.GetAlpha() * std::pow(1 + leaf_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1 + leaf_depth + 1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1 + leaf_depth + 1, -tree_prior.GetBeta()); // Determine whether a "grow" move is possible from the newly formed tree // in order to compute the probability of choosing "prune" from the new tree // (which is always possible by construction) bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); - bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); - bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); + bool min_samples_left_check = left_n >= 2 * tree_prior.GetMinSamplesLeaf(); + bool min_samples_right_check = right_n >= 2 * tree_prior.GetMinSamplesLeaf(); double prob_prune_new; if (non_constant && (min_samples_left_check || min_samples_right_check)) { prob_prune_new = 0.5; @@ 
-938,14 +928,12 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM // Determine the number of leaves in the current tree and leaf parents in the proposed tree int num_leaf_parents = tree->NumLeafParents(); - double p_leaf = 1/static_cast(num_leaves); - double p_leaf_parent = 1/static_cast(num_leaf_parents+1); + double p_leaf = 1 / static_cast(num_leaves); + double p_leaf_parent = 1 / static_cast(num_leaf_parents + 1); // Compute the final MH ratio - double log_mh_ratio = ( - std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + - std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood - ); + double log_mh_ratio = (std::log(pg) + std::log(1 - pgl) + std::log(1 - pgr) - std::log(1 - pg) + std::log(prob_prune_new) + + std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood); // Threshold at 0 if (log_mh_ratio > 0) { log_mh_ratio = 0; @@ -967,35 +955,34 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM } template -static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, +static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Choose a "leaf parent" node at random int num_leaves = tree->NumLeaves(); int num_leaf_parents = tree->NumLeafParents(); std::vector leaf_parents = tree->GetLeafParents(); std::vector leaf_parent_weights(num_leaf_parents); - std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents); + std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0 / num_leaf_parents); walker_vose leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end()); int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)]; int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen); int left_node = tree->LeftChild(leaf_parent_chosen); int right_node = tree->RightChild(leaf_parent_chosen); int feature_split = tree->SplitIndex(leaf_parent_chosen); - + // Compute the marginal likelihood for the leaf parent and its left and right nodes std::tuple split_eval = EvaluateExistingSplit( - dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args... 
- ); + dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); int32_t left_n = std::get<2>(split_eval); int32_t right_n = std::get<3>(split_eval); - + // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta()); + double pg = tree_prior.GetAlpha() * std::pow(1 + leaf_parent_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1 + leaf_parent_depth + 1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1 + leaf_parent_depth + 1, -tree_prior.GetBeta()); // Determine whether a "prune" move is possible from the new tree, // in order to compute the probability of choosing "grow" from the new tree @@ -1020,14 +1007,12 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf } // Determine the number of leaves in the current tree and leaf parents in the proposed tree - double p_leaf = 1/static_cast(num_leaves-1); - double p_leaf_parent = 1/static_cast(num_leaf_parents); + double p_leaf = 1 / static_cast(num_leaves - 1); + double p_leaf_parent = 1 / static_cast(num_leaf_parents); // Compute the final MH ratio - double log_mh_ratio = ( - std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) + - std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood - ); + double log_mh_ratio = (std::log(1 - pg) - std::log(pg) - std::log(1 - pgl) - std::log(1 - pgr) + 
std::log(prob_prune_old) + + std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood); // Threshold at 0 if (log_mh_ratio > 0) { log_mh_ratio = 0; @@ -1046,12 +1031,12 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf template static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Determine whether it is possible to grow any of the leaves bool grow_possible = false; std::vector leaves = tree->GetLeaves(); - for (auto& leaf: leaves) { + for (auto& leaf : leaves) { if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) { grow_possible = true; break; @@ -1084,15 +1069,13 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For // Draw a split rule at random data_size_t step_chosen = step_dist(gen); bool accept; - + if (step_chosen == 0) { MCMCGrowTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args... - ); + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args...); } else { MCMCPruneTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args... 
- ); + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args...); } } @@ -1100,10 +1083,10 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For * \brief Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for every tree in a forest: * 1. Sampling "birth-death" tree modifications via the Metropolis-Hastings algorithm * 2. Sampling leaf node parameters, conditional on a (possibly-updated) tree, via a Gibbs sampler - * + * * \tparam LeafModel Leaf model type (i.e. `GaussianConstantLeafModel`, `GaussianUnivariateRegressionLeafModel`, etc...) * \tparam LeafSuffStat Leaf sufficient statistic type (i.e. `GaussianConstantSuffStat`, `GaussianUnivariateRegressionSuffStat`, etc...) - * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, + * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, * this is `int`, while each of the other three sufficient statistic classes do not take a constructor argument. * \param active_forest Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class * of stored ensemble samples because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler). @@ -1125,33 +1108,32 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. 
*/ template -static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, +static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { // Adjust any model state needed to run a tree sampler - // For models that involve Bayesian backfitting, this amounts to adding tree i's + // For models that involve Bayesian backfitting, this amounts to adding tree i's // predictions back to the residual (thus, training a model on the "partial residual") // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object Tree* tree = active_forest.GetTree(i); AdjustStateBeforeTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i); - + // Sample tree i tree = active_forest.GetTree(i); MCMCSampleTreeOneIter( - tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, - global_variance, num_threads, leaf_suff_stat_args... 
- ); - + tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, + global_variance, num_threads, leaf_suff_stat_args...); + // Sample leaf parameters for tree i tree = active_forest.GetTree(i); leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); - + // Adjust any model state needed to run a tree sampler - // For models that involve Bayesian backfitting, this amounts to subtracting tree i's + // For models that involve Bayesian backfitting, this amounts to subtracting tree i's // predictions back out of the residual (thus, using an updated "partial residual" in the following interation). // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object AdjustStateAfterTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i); @@ -1162,8 +1144,8 @@ static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& } } -/*! \} */ // end of sampling_group +/*! \} */ // end of sampling_group -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_TREE_SAMPLER_H_ +#endif // STOCHTREE_TREE_SAMPLER_H_ diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index af6bbd0d..c738391d 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -19,7 +19,7 @@ namespace StochTree { /*! 
\brief Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model */ class GlobalHomoskedasticVarianceModel { public: - GlobalHomoskedasticVarianceModel() {ig_sampler_ = InverseGammaSampler();} + GlobalHomoskedasticVarianceModel() { ig_sampler_ = InverseGammaSampler(); } ~GlobalHomoskedasticVarianceModel() {} double PosteriorShape(Eigen::VectorXd& residuals, double a, double b) { data_size_t n = residuals.rows(); @@ -55,6 +55,7 @@ class GlobalHomoskedasticVarianceModel { double ig_scale = PosteriorScale(residuals, weights, a, b); return ig_sampler_.Sample(ig_shape, ig_scale, gen); } + private: InverseGammaSampler ig_sampler_; }; @@ -62,25 +63,26 @@ class GlobalHomoskedasticVarianceModel { /*! \brief Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model */ class LeafNodeHomoskedasticVarianceModel { public: - LeafNodeHomoskedasticVarianceModel() {ig_sampler_ = InverseGammaSampler();} + LeafNodeHomoskedasticVarianceModel() { ig_sampler_ = InverseGammaSampler(); } ~LeafNodeHomoskedasticVarianceModel() {} double PosteriorShape(TreeEnsemble* ensemble, double a, double b) { data_size_t num_leaves = ensemble->NumLeaves(); - return (a/2.0) + (num_leaves/2.0); + return (a / 2.0) + (num_leaves / 2.0); } double PosteriorScale(TreeEnsemble* ensemble, double a, double b) { double mu_sq = ensemble->SumLeafSquared(); - return (b/2.0) + (mu_sq/2.0); + return (b / 2.0) + (mu_sq / 2.0); } double SampleVarianceParameter(TreeEnsemble* ensemble, double a, double b, std::mt19937& gen) { double ig_shape = PosteriorShape(ensemble, a, b); double ig_scale = PosteriorScale(ensemble, a, b); return ig_sampler_.Sample(ig_shape, ig_scale, gen); } + private: InverseGammaSampler ig_sampler_; }; -} // namespace StochTree +} // namespace StochTree -#endif // STOCHTREE_VARIANCE_MODEL_H_ \ No newline at end of file +#endif // STOCHTREE_VARIANCE_MODEL_H_ \ No newline at end of file From 
d843d3f5cfa7365d5d8e63b79bd703fdae862038 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:17:09 -0400 Subject: [PATCH 05/64] Auto-formatted C++ source files to match style standardized in `.clang-format` --- src/R_data.cpp | 434 ++++++------- src/R_random_effects.cpp | 390 ++++++------ src/R_utils.cpp | 34 +- src/container.cpp | 21 +- src/cutpoint_candidates.cpp | 52 +- src/data.cpp | 12 +- src/forest.cpp | 1141 ++++++++++++++++++----------------- src/kernel.cpp | 59 +- src/leaf_model.cpp | 80 +-- src/ordinal_sampler.cpp | 20 +- src/partition_tracker.cpp | 20 +- src/py_stochtree.cpp | 739 ++++++++++++----------- src/random_effects.cpp | 126 ++-- src/sampler.cpp | 451 +++++++------- src/serialization.cpp | 452 +++++++------- src/tree.cpp | 92 +-- 16 files changed, 2068 insertions(+), 2055 deletions(-) diff --git a/src/R_data.cpp b/src/R_data.cpp index 3e96e0fc..681ca622 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -8,383 +8,383 @@ [[cpp11::register]] cpp11::external_pointer create_forest_dataset_cpp() { - // Create smart pointer to newly allocated object - std::unique_ptr dataset_ptr_ = std::make_unique(); - - // Release management of the pointer to R session - return cpp11::external_pointer(dataset_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr dataset_ptr_ = std::make_unique(); + + // Release management of the pointer to R session + return cpp11::external_pointer(dataset_ptr_.release()); } [[cpp11::register]] int dataset_num_rows_cpp(cpp11::external_pointer dataset) { - return dataset->NumObservations(); + return dataset->NumObservations(); } [[cpp11::register]] int dataset_num_covariates_cpp(cpp11::external_pointer dataset) { - return dataset->NumCovariates(); + return dataset->NumCovariates(); } [[cpp11::register]] int dataset_num_basis_cpp(cpp11::external_pointer dataset) { - return dataset->NumBasis(); + return dataset->NumBasis(); } [[cpp11::register]] bool 
dataset_has_basis_cpp(cpp11::external_pointer dataset) { - return dataset->HasBasis(); + return dataset->HasBasis(); } [[cpp11::register]] bool dataset_has_variance_weights_cpp(cpp11::external_pointer dataset) { - return dataset->HasVarWeights(); + return dataset->HasVarWeights(); } [[cpp11::register]] void forest_dataset_add_covariates_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> covariates) { - // TODO: add handling code on the R side to ensure matrices are column-major - bool row_major{false}; + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add covariates + StochTree::data_size_t n = covariates.nrow(); + int num_covariates = covariates.ncol(); + double* covariate_data_ptr = REAL(PROTECT(covariates)); + dataset_ptr->AddCovariates(covariate_data_ptr, n, num_covariates, row_major); - // Add covariates - StochTree::data_size_t n = covariates.nrow(); - int num_covariates = covariates.ncol(); - double* covariate_data_ptr = REAL(PROTECT(covariates)); - dataset_ptr->AddCovariates(covariate_data_ptr, n, num_covariates, row_major); - - // Unprotect pointers to R data - UNPROTECT(1); + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void forest_dataset_add_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { - // TODO: add handling code on the R side to ensure matrices are column-major - bool row_major{false}; + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; - // Add basis - StochTree::data_size_t n = basis.nrow(); - int num_basis = basis.ncol(); - double* basis_data_ptr = REAL(PROTECT(basis)); - dataset_ptr->AddBasis(basis_data_ptr, n, num_basis, row_major); - - // Unprotect pointers to R data - UNPROTECT(1); + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->AddBasis(basis_data_ptr, n, 
num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void forest_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { - // TODO: add handling code on the R side to ensure matrices are column-major - bool row_major{false}; - - // Add basis - StochTree::data_size_t n = basis.nrow(); - int num_basis = basis.ncol(); - double* basis_data_ptr = REAL(PROTECT(basis)); - dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); - - // Unprotect pointers to R data - UNPROTECT(1); + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { - // Add weights - StochTree::data_size_t n = weights.size(); - double* weight_data_ptr = REAL(PROTECT(weights)); - dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); - - // Unprotect pointers to R data - UNPROTECT(1); + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void forest_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { - // Add weights - StochTree::data_size_t n = weights.size(); - double* weight_data_ptr = REAL(PROTECT(weights)); - dataset_ptr->AddVarianceWeights(weight_data_ptr, n); + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + 
dataset_ptr->AddVarianceWeights(weight_data_ptr, n); - // Unprotect pointers to R data - UNPROTECT(1); + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] cpp11::writable::doubles_matrix<> forest_dataset_get_covariates_cpp(cpp11::external_pointer dataset_ptr) { - // Initialize output matrix - int num_row = dataset_ptr->NumObservations(); - int num_col = dataset_ptr->NumCovariates(); - cpp11::writable::doubles_matrix<> output(num_row, num_col); - - for (int i = 0; i < num_row; i++) { - for (int j = 0; j < num_col; j++) { - output(i, j) = dataset_ptr->CovariateValue(i, j); - } + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumCovariates(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->CovariateValue(i, j); } + } - return output; + return output; } [[cpp11::register]] cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { - // Initialize output matrix - int num_row = dataset_ptr->NumObservations(); - int num_col = dataset_ptr->NumBasis(); - cpp11::writable::doubles_matrix<> output(num_row, num_col); - for (int i = 0; i < num_row; i++) { - for (int j = 0; j < num_col; j++) { - output(i, j) = dataset_ptr->BasisValue(i, j); - } - } - return output; + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBasis(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; } [[cpp11::register]] cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { - // Initialize output vector - int num_row = dataset_ptr->NumObservations(); - cpp11::writable::doubles output(num_row); - for 
(int i = 0; i < num_row; i++) { - output.at(i) = dataset_ptr->VarWeightValue(i); - } - return output; + // Initialize output vector + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; } [[cpp11::register]] bool forest_dataset_has_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_idx) { - return dataset_ptr->HasAuxiliaryDimension(dim_idx); + return dataset_ptr->HasAuxiliaryDimension(dim_idx); } [[cpp11::register]] void forest_dataset_add_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_size) { - dataset_ptr->AddAuxiliaryDimension(dim_size); + dataset_ptr->AddAuxiliaryDimension(dim_size); } [[cpp11::register]] double forest_dataset_get_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx) { - return dataset_ptr->GetAuxiliaryDataValue(dim_idx, element_idx); + return dataset_ptr->GetAuxiliaryDataValue(dim_idx, element_idx); } [[cpp11::register]] void forest_dataset_set_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx, double value) { - dataset_ptr->SetAuxiliaryDataValue(dim_idx, element_idx, value); + dataset_ptr->SetAuxiliaryDataValue(dim_idx, element_idx, value); } [[cpp11::register]] cpp11::writable::doubles forest_dataset_get_auxiliary_data_vector_cpp(cpp11::external_pointer dataset_ptr, int dim_idx) { - const std::vector output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx); - int n = output_raw.size(); - cpp11::writable::doubles output(n); - for (int i = 0; i < n; i++) { - output[i] = output_raw[i]; - } - return output; + const std::vector output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx); + int n = output_raw.size(); + cpp11::writable::doubles output(n); + for (int i = 0; i < n; i++) { + output[i] = output_raw[i]; + } + return output; } [[cpp11::register]] cpp11::external_pointer 
create_column_vector_cpp(cpp11::doubles outcome) { - // Unpack pointers to data and dimensions - StochTree::data_size_t n = outcome.size(); - double* outcome_data_ptr = REAL(PROTECT(outcome)); + // Unpack pointers to data and dimensions + StochTree::data_size_t n = outcome.size(); + double* outcome_data_ptr = REAL(PROTECT(outcome)); + + // Create smart pointer + std::unique_ptr vector_ptr_ = std::make_unique(outcome_data_ptr, n); - // Create smart pointer - std::unique_ptr vector_ptr_ = std::make_unique(outcome_data_ptr, n); - - // Unprotect pointers to R data - UNPROTECT(1); - - // Release management of the pointer to R session - return cpp11::external_pointer(vector_ptr_.release()); + // Unprotect pointers to R data + UNPROTECT(1); + + // Release management of the pointer to R session + return cpp11::external_pointer(vector_ptr_.release()); } [[cpp11::register]] void add_to_column_vector_cpp(cpp11::external_pointer outcome, cpp11::doubles update_vector) { - // Unpack pointers to data and dimensions - StochTree::data_size_t n = update_vector.size(); - double* update_data_ptr = REAL(PROTECT(update_vector)); - - // Add to the outcome data using the C++ API - outcome->AddToData(update_data_ptr, n); - - // Unprotect pointers to R data - UNPROTECT(1); + // Unpack pointers to data and dimensions + StochTree::data_size_t n = update_vector.size(); + double* update_data_ptr = REAL(PROTECT(update_vector)); + + // Add to the outcome data using the C++ API + outcome->AddToData(update_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void subtract_from_column_vector_cpp(cpp11::external_pointer outcome, cpp11::doubles update_vector) { - // Unpack pointers to data and dimensions - StochTree::data_size_t n = update_vector.size(); - double* update_data_ptr = REAL(PROTECT(update_vector)); - - // Add to the outcome data using the C++ API - outcome->SubtractFromData(update_data_ptr, n); - - // Unprotect pointers to R data - UNPROTECT(1); + // 
Unpack pointers to data and dimensions + StochTree::data_size_t n = update_vector.size(); + double* update_data_ptr = REAL(PROTECT(update_vector)); + + // Add to the outcome data using the C++ API + outcome->SubtractFromData(update_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void overwrite_column_vector_cpp(cpp11::external_pointer outcome, cpp11::doubles new_vector) { - // Unpack pointers to data and dimensions - StochTree::data_size_t n = new_vector.size(); - double* update_data_ptr = REAL(PROTECT(new_vector)); - - // Add to the outcome data using the C++ API - outcome->OverwriteData(update_data_ptr, n); - - // Unprotect pointers to R data - UNPROTECT(1); + // Unpack pointers to data and dimensions + StochTree::data_size_t n = new_vector.size(); + double* update_data_ptr = REAL(PROTECT(new_vector)); + + // Add to the outcome data using the C++ API + outcome->OverwriteData(update_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] -void propagate_trees_column_vector_cpp(cpp11::external_pointer tracker, +void propagate_trees_column_vector_cpp(cpp11::external_pointer tracker, cpp11::external_pointer residual) { - StochTree::UpdateResidualNewOutcome(*tracker, *residual); + StochTree::UpdateResidualNewOutcome(*tracker, *residual); } [[cpp11::register]] cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer vector_ptr) { - // Initialize output vector - StochTree::data_size_t n = vector_ptr->NumRows(); - cpp11::writable::doubles output(n); - - // Unpack data - for (StochTree::data_size_t i = 0; i < n; i++) { - output.at(i) = vector_ptr->GetElement(i); - } - - // Release management of the pointer to R session - return output; + // Initialize output vector + StochTree::data_size_t n = vector_ptr->NumRows(); + cpp11::writable::doubles output(n); + + // Unpack data + for (StochTree::data_size_t i = 0; i < n; i++) { + output.at(i) = vector_ptr->GetElement(i); + } + + // Release 
management of the pointer to R session + return output; } [[cpp11::register]] cpp11::external_pointer create_rfx_dataset_cpp() { - // Create smart pointer to newly allocated object - std::unique_ptr dataset_ptr_ = std::make_unique(); - - // Release management of the pointer to R session - return cpp11::external_pointer(dataset_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr dataset_ptr_ = std::make_unique(); + + // Release management of the pointer to R session + return cpp11::external_pointer(dataset_ptr_.release()); } [[cpp11::register]] void rfx_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { - // TODO: add handling code on the R side to ensure matrices are column-major - bool row_major{false}; - - // Add basis - StochTree::data_size_t n = basis.nrow(); - int num_basis = basis.ncol(); - double* basis_data_ptr = REAL(PROTECT(basis)); - dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); - - // Unprotect pointers to R data - UNPROTECT(1); + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { - // Add weights - StochTree::data_size_t n = weights.size(); - double* weight_data_ptr = REAL(PROTECT(weights)); - dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); - - // Unprotect pointers to R data - UNPROTECT(1); + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); + + // Unprotect 
pointers to R data + UNPROTECT(1); } [[cpp11::register]] void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels) { - // Update group labels - int n = group_labels.size(); - std::vector group_labels_vec(group_labels.begin(), group_labels.end()); - dataset_ptr->UpdateGroupLabels(group_labels_vec, n); + // Update group labels + int n = group_labels.size(); + std::vector group_labels_vec(group_labels.begin(), group_labels.end()); + dataset_ptr->UpdateGroupLabels(group_labels_vec, n); } [[cpp11::register]] int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset) { - return dataset->NumBases(); + return dataset->NumBases(); } [[cpp11::register]] int rfx_dataset_num_rows_cpp(cpp11::external_pointer dataset) { - return dataset->NumObservations(); + return dataset->NumObservations(); } [[cpp11::register]] bool rfx_dataset_has_group_labels_cpp(cpp11::external_pointer dataset) { - return dataset->HasGroupLabels(); + return dataset->HasGroupLabels(); } [[cpp11::register]] bool rfx_dataset_has_basis_cpp(cpp11::external_pointer dataset) { - return dataset->HasBasis(); + return dataset->HasBasis(); } [[cpp11::register]] bool rfx_dataset_has_variance_weights_cpp(cpp11::external_pointer dataset) { - return dataset->HasVarWeights(); + return dataset->HasVarWeights(); } [[cpp11::register]] void rfx_dataset_add_group_labels_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels) { - // Add group labels - std::vector group_labels_vec(group_labels.begin(), group_labels.end()); - dataset_ptr->AddGroupLabels(group_labels_vec); + // Add group labels + std::vector group_labels_vec(group_labels.begin(), group_labels.end()); + dataset_ptr->AddGroupLabels(group_labels_vec); } [[cpp11::register]] void rfx_dataset_add_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { - // TODO: add handling code on the R side to ensure matrices are column-major - bool row_major{false}; - - // Add basis - 
StochTree::data_size_t n = basis.nrow(); - int num_basis = basis.ncol(); - double* basis_data_ptr = REAL(PROTECT(basis)); - dataset_ptr->AddBasis(basis_data_ptr, n, num_basis, row_major); - - // Unprotect pointers to R data - UNPROTECT(1); + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->AddBasis(basis_data_ptr, n, num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] void rfx_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { - // Add weights - StochTree::data_size_t n = weights.size(); - double* weight_data_ptr = REAL(PROTECT(weights)); - dataset_ptr->AddVarianceWeights(weight_data_ptr, n); - - // Unprotect pointers to R data - UNPROTECT(1); + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->AddVarianceWeights(weight_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); } [[cpp11::register]] cpp11::writable::integers rfx_dataset_get_group_labels_cpp(cpp11::external_pointer dataset_ptr) { - int num_row = dataset_ptr->NumObservations(); - cpp11::writable::integers output(num_row); - for (int i = 0; i < num_row; i++) { - output.at(i) = dataset_ptr->GroupId(i); - } - return output; + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::integers output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->GroupId(i); + } + return output; } [[cpp11::register]] cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { - int num_row = dataset_ptr->NumObservations(); - int num_col = dataset_ptr->NumBases(); - cpp11::writable::doubles_matrix<> output(num_row, num_col); - for (int i = 0; i < num_row; i++) { - for (int j = 0; 
j < num_col; j++) { - output(i, j) = dataset_ptr->BasisValue(i, j); - } - } - return output; + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBases(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; } [[cpp11::register]] cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { - int num_row = dataset_ptr->NumObservations(); - cpp11::writable::doubles output(num_row); - for (int i = 0; i < num_row; i++) { - output.at(i) = dataset_ptr->VarWeightValue(i); - } - return output; + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; } diff --git a/src/R_random_effects.cpp b/src/R_random_effects.cpp index fffba538..cddd0f9e 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -11,334 +11,334 @@ [[cpp11::register]] cpp11::external_pointer rfx_container_cpp(int num_components, int num_groups) { - // Create smart pointer to newly allocated object - std::unique_ptr rfx_container_ptr_ = std::make_unique(num_components, num_groups); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_container_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr rfx_container_ptr_ = std::make_unique(num_components, num_groups); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_container_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer rfx_container_from_json_cpp(cpp11::external_pointer json_ptr, std::string rfx_label) { - // Create smart pointer to newly allocated object - std::unique_ptr rfx_container_ptr_ = std::make_unique(); - - // Extract the random effect container's 
json - nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - rfx_container_ptr_->Reset(); - rfx_container_ptr_->from_json(rfx_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_container_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr rfx_container_ptr_ = std::make_unique(); + + // Extract the random effect container's json + nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr_->Reset(); + rfx_container_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_container_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer rfx_label_mapper_from_json_cpp(cpp11::external_pointer json_ptr, std::string rfx_label) { - // Create smart pointer to newly allocated object - std::unique_ptr label_mapper_ptr_ = std::make_unique(); - - // Extract the label mapper's json - nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); - - // Reset the label mapper using the json - label_mapper_ptr_->Reset(); - label_mapper_ptr_->from_json(rfx_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(label_mapper_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr label_mapper_ptr_ = std::make_unique(); + + // Extract the label mapper's json + nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); + + // Reset the label mapper using the json + label_mapper_ptr_->Reset(); + label_mapper_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(label_mapper_ptr_.release()); } [[cpp11::register]] cpp11::writable::integers rfx_group_ids_from_json_cpp(cpp11::external_pointer json_ptr, std::string rfx_label) { - // 
Create smart pointer to newly allocated object - cpp11::writable::integers output; - - // Extract the groupids' json - nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - int num_groups = rfx_json.size(); - for (int i = 0; i < num_groups; i++) { - output.push_back(rfx_json.at(i)); - } - - return output; + // Create smart pointer to newly allocated object + cpp11::writable::integers output; + + // Extract the groupids' json + nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + int num_groups = rfx_json.size(); + for (int i = 0; i < num_groups; i++) { + output.push_back(rfx_json.at(i)); + } + + return output; } [[cpp11::register]] void rfx_container_append_from_json_cpp(cpp11::external_pointer rfx_container_ptr, cpp11::external_pointer json_ptr, std::string rfx_label) { - // Extract the random effect container's json - nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - rfx_container_ptr->append_from_json(rfx_json); + // Extract the random effect container's json + nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr->append_from_json(rfx_json); } [[cpp11::register]] cpp11::external_pointer rfx_container_from_json_string_cpp(std::string json_string, std::string rfx_label) { - // Create smart pointer to newly allocated object - std::unique_ptr rfx_container_ptr_ = std::make_unique(); - - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the random effect container's json - nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - rfx_container_ptr_->Reset(); - 
rfx_container_ptr_->from_json(rfx_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_container_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr rfx_container_ptr_ = std::make_unique(); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the random effect container's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr_->Reset(); + rfx_container_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_container_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer rfx_label_mapper_from_json_string_cpp(std::string json_string, std::string rfx_label) { - // Create smart pointer to newly allocated object - std::unique_ptr label_mapper_ptr_ = std::make_unique(); - - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the label mapper's json - nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); - - // Reset the label mapper using the json - label_mapper_ptr_->Reset(); - label_mapper_ptr_->from_json(rfx_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(label_mapper_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr label_mapper_ptr_ = std::make_unique(); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the label mapper's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the label mapper using the json + label_mapper_ptr_->Reset(); + label_mapper_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + 
return cpp11::external_pointer(label_mapper_ptr_.release()); } [[cpp11::register]] cpp11::writable::integers rfx_group_ids_from_json_string_cpp(std::string json_string, std::string rfx_label) { - // Create smart pointer to newly allocated object - cpp11::writable::integers output; - - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the groupids' json - nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - int num_groups = rfx_json.size(); - for (int i = 0; i < num_groups; i++) { - output.push_back(rfx_json.at(i)); - } - - return output; + // Create smart pointer to newly allocated object + cpp11::writable::integers output; + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the groupids' json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + int num_groups = rfx_json.size(); + for (int i = 0; i < num_groups; i++) { + output.push_back(rfx_json.at(i)); + } + + return output; } [[cpp11::register]] void rfx_container_append_from_json_string_cpp(cpp11::external_pointer rfx_container_ptr, std::string json_string, std::string rfx_label) { - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the random effect container's json - nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); - - // Reset the forest sample container using the json - rfx_container_ptr->append_from_json(rfx_json); + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the random effect container's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest 
sample container using the json + rfx_container_ptr->append_from_json(rfx_json); } [[cpp11::register]] cpp11::external_pointer rfx_model_cpp(int num_components, int num_groups) { - // Create smart pointer to newly allocated object - std::unique_ptr rfx_model_ptr_ = std::make_unique(num_components, num_groups); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_model_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr rfx_model_ptr_ = std::make_unique(num_components, num_groups); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_model_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer rfx_tracker_cpp(cpp11::integers group_labels) { - // Convert group_labels to a std::vector - std::vector group_labels_vec(group_labels.begin(), group_labels.end()); - - // Create smart pointer to newly allocated object - std::unique_ptr rfx_tracker_ptr_ = std::make_unique(group_labels_vec); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_tracker_ptr_.release()); + // Convert group_labels to a std::vector + std::vector group_labels_vec(group_labels.begin(), group_labels.end()); + + // Create smart pointer to newly allocated object + std::unique_ptr rfx_tracker_ptr_ = std::make_unique(group_labels_vec); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_tracker_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer rfx_label_mapper_cpp(cpp11::external_pointer rfx_tracker) { - // Create smart pointer to newly allocated object - std::unique_ptr rfx_label_mapper_ptr_ = std::make_unique(rfx_tracker->GetLabelMap()); - - // Release management of the pointer to R session - return cpp11::external_pointer(rfx_label_mapper_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr rfx_label_mapper_ptr_ = std::make_unique(rfx_tracker->GetLabelMap()); + + 
// Release management of the pointer to R session + return cpp11::external_pointer(rfx_label_mapper_ptr_.release()); } [[cpp11::register]] -void rfx_model_sample_random_effects_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, - cpp11::external_pointer residual, cpp11::external_pointer rfx_tracker, +void rfx_model_sample_random_effects_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, + cpp11::external_pointer residual, cpp11::external_pointer rfx_tracker, cpp11::external_pointer rfx_container, bool keep_sample, double global_variance, cpp11::external_pointer rng) { - rfx_model->SampleRandomEffects(*rfx_dataset, *residual, *rfx_tracker, global_variance, *rng); - if (keep_sample) rfx_container->AddSample(*rfx_model); + rfx_model->SampleRandomEffects(*rfx_dataset, *residual, *rfx_tracker, global_variance, *rng); + if (keep_sample) rfx_container->AddSample(*rfx_model); } [[cpp11::register]] -cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer rfx_model, - cpp11::external_pointer rfx_dataset, +cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer rfx_model, + cpp11::external_pointer rfx_dataset, cpp11::external_pointer rfx_tracker) { - std::vector output = rfx_model->Predict(*rfx_dataset, *rfx_tracker); - return output; + std::vector output = rfx_model->Predict(*rfx_dataset, *rfx_tracker); + return output; } [[cpp11::register]] -cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer rfx_container, - cpp11::external_pointer rfx_dataset, +cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer rfx_container, + cpp11::external_pointer rfx_dataset, cpp11::external_pointer label_mapper) { - int num_observations = rfx_dataset->NumObservations(); - int num_samples = rfx_container->NumSamples(); - std::vector output(num_observations*num_samples); - rfx_container->Predict(*rfx_dataset, *label_mapper, output); - return output; + int num_observations = 
rfx_dataset->NumObservations(); + int num_samples = rfx_container->NumSamples(); + std::vector output(num_observations * num_samples); + rfx_container->Predict(*rfx_dataset, *label_mapper, output); + return output; } [[cpp11::register]] int rfx_container_num_samples_cpp(cpp11::external_pointer rfx_container) { - return rfx_container->NumSamples(); + return rfx_container->NumSamples(); } [[cpp11::register]] int rfx_container_num_components_cpp(cpp11::external_pointer rfx_container) { - return rfx_container->NumComponents(); + return rfx_container->NumComponents(); } [[cpp11::register]] int rfx_container_num_groups_cpp(cpp11::external_pointer rfx_container) { - return rfx_container->NumGroups(); + return rfx_container->NumGroups(); } [[cpp11::register]] void rfx_container_delete_sample_cpp(cpp11::external_pointer rfx_container, int sample_num) { - rfx_container->DeleteSample(sample_num); + rfx_container->DeleteSample(sample_num); } [[cpp11::register]] void rfx_model_set_working_parameter_cpp(cpp11::external_pointer rfx_model, cpp11::doubles working_param_init) { - Eigen::VectorXd working_param_eigen(working_param_init.size()); - for (int i = 0; i < working_param_init.size(); i++) { - working_param_eigen(i) = working_param_init.at(i); - } - rfx_model->SetWorkingParameter(working_param_eigen); + Eigen::VectorXd working_param_eigen(working_param_init.size()); + for (int i = 0; i < working_param_init.size(); i++) { + working_param_eigen(i) = working_param_init.at(i); + } + rfx_model->SetWorkingParameter(working_param_eigen); } [[cpp11::register]] void rfx_model_set_group_parameters_cpp(cpp11::external_pointer rfx_model, cpp11::doubles_matrix<> group_params_init) { - Eigen::MatrixXd group_params_eigen(group_params_init.nrow(), group_params_init.ncol()); - for (int i = 0; i < group_params_init.nrow(); i++) { - for (int j = 0; j < group_params_init.ncol(); j++) { - group_params_eigen(i,j) = group_params_init(i,j); - } + Eigen::MatrixXd 
group_params_eigen(group_params_init.nrow(), group_params_init.ncol()); + for (int i = 0; i < group_params_init.nrow(); i++) { + for (int j = 0; j < group_params_init.ncol(); j++) { + group_params_eigen(i, j) = group_params_init(i, j); } - rfx_model->SetGroupParameters(group_params_eigen); + } + rfx_model->SetGroupParameters(group_params_eigen); } [[cpp11::register]] void rfx_model_set_working_parameter_covariance_cpp(cpp11::external_pointer rfx_model, cpp11::doubles_matrix<> working_param_cov_init) { - Eigen::MatrixXd working_param_cov_eigen(working_param_cov_init.nrow(), working_param_cov_init.ncol()); - for (int i = 0; i < working_param_cov_init.nrow(); i++) { - for (int j = 0; j < working_param_cov_init.ncol(); j++) { - working_param_cov_eigen(i,j) = working_param_cov_init(i,j); - } + Eigen::MatrixXd working_param_cov_eigen(working_param_cov_init.nrow(), working_param_cov_init.ncol()); + for (int i = 0; i < working_param_cov_init.nrow(); i++) { + for (int j = 0; j < working_param_cov_init.ncol(); j++) { + working_param_cov_eigen(i, j) = working_param_cov_init(i, j); } - rfx_model->SetWorkingParameterCovariance(working_param_cov_eigen); + } + rfx_model->SetWorkingParameterCovariance(working_param_cov_eigen); } [[cpp11::register]] void rfx_model_set_group_parameter_covariance_cpp(cpp11::external_pointer rfx_model, cpp11::doubles_matrix<> group_param_cov_init) { - Eigen::MatrixXd group_param_cov_eigen(group_param_cov_init.nrow(), group_param_cov_init.ncol()); - for (int i = 0; i < group_param_cov_init.nrow(); i++) { - for (int j = 0; j < group_param_cov_init.ncol(); j++) { - group_param_cov_eigen(i,j) = group_param_cov_init(i,j); - } + Eigen::MatrixXd group_param_cov_eigen(group_param_cov_init.nrow(), group_param_cov_init.ncol()); + for (int i = 0; i < group_param_cov_init.nrow(); i++) { + for (int j = 0; j < group_param_cov_init.ncol(); j++) { + group_param_cov_eigen(i, j) = group_param_cov_init(i, j); } - 
rfx_model->SetGroupParameterCovariance(group_param_cov_eigen); + } + rfx_model->SetGroupParameterCovariance(group_param_cov_eigen); } [[cpp11::register]] void rfx_model_set_variance_prior_shape_cpp(cpp11::external_pointer rfx_model, double shape) { - rfx_model->SetVariancePriorShape(shape); + rfx_model->SetVariancePriorShape(shape); } [[cpp11::register]] void rfx_model_set_variance_prior_scale_cpp(cpp11::external_pointer rfx_model, double scale) { - rfx_model->SetVariancePriorScale(scale); + rfx_model->SetVariancePriorScale(scale); } [[cpp11::register]] cpp11::writable::integers rfx_tracker_get_unique_group_ids_cpp(cpp11::external_pointer rfx_tracker) { - std::vector output = rfx_tracker->GetUniqueGroupIds(); - return output; + std::vector output = rfx_tracker->GetUniqueGroupIds(); + return output; } [[cpp11::register]] cpp11::writable::doubles rfx_container_get_beta_cpp(cpp11::external_pointer rfx_container_ptr) { - return rfx_container_ptr->GetBeta(); + return rfx_container_ptr->GetBeta(); } [[cpp11::register]] cpp11::writable::doubles rfx_container_get_alpha_cpp(cpp11::external_pointer rfx_container_ptr) { - return rfx_container_ptr->GetAlpha(); + return rfx_container_ptr->GetAlpha(); } [[cpp11::register]] cpp11::writable::doubles rfx_container_get_xi_cpp(cpp11::external_pointer rfx_container_ptr) { - return rfx_container_ptr->GetXi(); + return rfx_container_ptr->GetXi(); } [[cpp11::register]] cpp11::writable::doubles rfx_container_get_sigma_cpp(cpp11::external_pointer rfx_container_ptr) { - return rfx_container_ptr->GetSigma(); + return rfx_container_ptr->GetSigma(); } [[cpp11::register]] cpp11::list rfx_label_mapper_to_list_cpp(cpp11::external_pointer label_mapper_ptr) { - cpp11::writable::integers keys; - cpp11::writable::integers values; - std::map label_map = label_mapper_ptr->Map(); - for (const auto& [key, value] : label_map) { - keys.push_back(key); - values.push_back(value); - } - - cpp11::writable::list output; - output.push_back(keys); - 
output.push_back(values); - return output; + cpp11::writable::integers keys; + cpp11::writable::integers values; + std::map label_map = label_mapper_ptr->Map(); + for (const auto& [key, value] : label_map) { + keys.push_back(key); + values.push_back(value); + } + + cpp11::writable::list output; + output.push_back(keys); + output.push_back(values); + return output; } [[cpp11::register]] -void reset_rfx_model_cpp(cpp11::external_pointer rfx_model, - cpp11::external_pointer rfx_container, +void reset_rfx_model_cpp(cpp11::external_pointer rfx_model, + cpp11::external_pointer rfx_container, int sample_num) { - // Reset the RFX model from a previous sample - rfx_model->ResetFromSample(*rfx_container, sample_num); + // Reset the RFX model from a previous sample + rfx_model->ResetFromSample(*rfx_container, sample_num); } [[cpp11::register]] -void reset_rfx_tracker_cpp(cpp11::external_pointer tracker, - cpp11::external_pointer dataset, - cpp11::external_pointer residual, +void reset_rfx_tracker_cpp(cpp11::external_pointer tracker, + cpp11::external_pointer dataset, + cpp11::external_pointer residual, cpp11::external_pointer rfx_model) { - // Reset the RFX tracker from a previous sample - tracker->ResetFromSample(*rfx_model, *dataset, *residual); + // Reset the RFX tracker from a previous sample + tracker->ResetFromSample(*rfx_model, *dataset, *residual); } [[cpp11::register]] -void root_reset_rfx_tracker_cpp(cpp11::external_pointer tracker, - cpp11::external_pointer dataset, - cpp11::external_pointer residual, +void root_reset_rfx_tracker_cpp(cpp11::external_pointer tracker, + cpp11::external_pointer dataset, + cpp11::external_pointer residual, cpp11::external_pointer rfx_model) { - // Reset the RFX tracker from root - tracker->RootReset(*rfx_model, *dataset, *residual); + // Reset the RFX tracker from root + tracker->RootReset(*rfx_model, *dataset, *residual); } diff --git a/src/R_utils.cpp b/src/R_utils.cpp index 038023cb..8df37da6 100644 --- a/src/R_utils.cpp +++ 
b/src/R_utils.cpp @@ -3,33 +3,33 @@ [[cpp11::register]] double sum_cpp(cpp11::doubles x) { - double output = 0.0; - for (int i = 0; i < x.size(); i++) { - output += x[i]; - } - return output; + double output = 0.0; + for (int i = 0; i < x.size(); i++) { + output += x[i]; + } + return output; } [[cpp11::register]] double mean_cpp(cpp11::doubles x) { - double output = 0.0; - for (int i = 0; i < x.size(); i++) { - output += x[i]; - } - return output / x.size(); + double output = 0.0; + for (int i = 0; i < x.size(); i++) { + output += x[i]; + } + return output / x.size(); } [[cpp11::register]] double var_cpp(cpp11::doubles x) { - double mean = mean_cpp(x); - double output = 0.0; - for (int i = 0; i < x.size(); i++) { - output += (x[i] - mean) * (x[i] - mean); - } - return output / (x.size() - 1); + double mean = mean_cpp(x); + double output = 0.0; + for (int i = 0; i < x.size(); i++) { + output += (x[i] - mean) * (x[i] - mean); + } + return output / (x.size() - 1); } [[cpp11::register]] double sd_cpp(cpp11::doubles x) { - return std::sqrt(var_cpp(x)); + return std::sqrt(var_cpp(x)); } diff --git a/src/container.cpp b/src/container.cpp index 0d7d3548..999f5be7 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -48,7 +48,7 @@ void ForestContainer::InitializeRoot(double leaf_value) { CHECK_EQ(forests_.size(), 0); forests_.resize(1); forests_[0].reset(new TreeEnsemble(num_trees_, output_dimension_, is_leaf_constant_, is_exponentiated_)); - // NOTE: not setting num_samples = 1, since we are just initializing constant root + // NOTE: not setting num_samples = 1, since we are just initializing constant root // nodes and the forest still needs to be sampled by either MCMC or GFR num_samples_ = 0; SetLeafValue(0, leaf_value); @@ -60,7 +60,7 @@ void ForestContainer::InitializeRoot(std::vector& leaf_vector) { CHECK_EQ(forests_.size(), 0); forests_.resize(1); forests_[0].reset(new TreeEnsemble(num_trees_, output_dimension_, is_leaf_constant_, is_exponentiated_)); - // NOTE: 
not setting num_samples = 1, since we are just initializing constant root + // NOTE: not setting num_samples = 1, since we are just initializing constant root // nodes and the forest still needs to be sampled by either MCMC or GFR num_samples_ = 0; SetLeafVector(0, leaf_vector); @@ -78,7 +78,7 @@ void ForestContainer::AddSamples(int num_samples) { std::vector ForestContainer::Predict(ForestDataset& dataset) { data_size_t n = dataset.NumObservations(); - data_size_t total_output_size = n*num_samples_; + data_size_t total_output_size = n * num_samples_; std::vector output(total_output_size); PredictInPlace(dataset, output); return output; @@ -110,7 +110,7 @@ std::vector ForestContainer::PredictRawSingleTree(ForestDataset& dataset void ForestContainer::PredictInPlace(ForestDataset& dataset, std::vector& output) { data_size_t n = dataset.NumObservations(); - data_size_t total_output_size = n*num_samples_; + data_size_t total_output_size = n * num_samples_; CHECK_EQ(total_output_size, output.size()); data_size_t offset = 0; for (int i = 0; i < num_samples_; i++) { @@ -146,14 +146,13 @@ void ForestContainer::PredictRawSingleTreeInPlace(ForestDataset& dataset, int fo data_size_t total_output_size = n * output_dimension_; CHECK_EQ(total_output_size, output.size()); data_size_t offset = 0; - forests_[forest_num]->PredictRawInplace(dataset, output, tree_num, tree_num+1, offset); + forests_[forest_num]->PredictRawInplace(dataset, output, tree_num, tree_num + 1, offset); } void ForestContainer::PredictLeafIndicesInplace( - Eigen::Map>& covariates, - Eigen::Map>& output, - std::vector& forest_indices, int num_trees, data_size_t n -) { + Eigen::Map>& covariates, + Eigen::Map>& output, + std::vector& forest_indices, int num_trees, data_size_t n) { int num_forests = forest_indices.size(); int forest_id; for (int i = 0; i < num_forests; i++) { @@ -177,7 +176,7 @@ json ForestContainer::to_json() { forest_label = "forest_" + std::to_string(i); result_obj.emplace(forest_label, 
forests_[i]->to_json()); } - + return result_obj; } @@ -222,4 +221,4 @@ void ForestContainer::append_from_json(const json& forest_container_json) { this->num_samples_ += new_num_samples; } -} // namespace StochTree +} // namespace StochTree diff --git a/src/cutpoint_candidates.cpp b/src/cutpoint_candidates.cpp index e43b8219..5c7848f0 100644 --- a/src/cutpoint_candidates.cpp +++ b/src/cutpoint_candidates.cpp @@ -26,15 +26,15 @@ void FeatureCutpointGrid::CalculateStridesNumeric(Eigen::MatrixXd& covariates, E data_size_t node_size = node_end - node_begin; // Check if node has fewer observations than cutpoint_grid_size if (node_size <= cutpoint_grid_size_) { - // In this case it is still possible to have "duplicates" if the values of - // a numeric feature are very close together which in practice will only + // In this case it is still possible to have "duplicates" if the values of + // a numeric feature are very close together which in practice will only // occur when a categorical was imported incorrectly as numeric. - // For this case, we run through the sorted data, determining the stride length + // For this case, we run through the sorted data, determining the stride length // of all unique values. 
EnumerateNumericCutpointsDeduplication(covariates, residuals, feature_node_sort_tracker, node_id, node_begin, node_end, node_size, feature_index); } else { // Here we must essentially "thin out" the possible cutpoints - // First, we determine a step size that ensures there will be as + // First, we determine a step size that ensures there will be as // many potential cutpoints as articulated in cutpoint_grid_size ScanNumericCutpoints(covariates, residuals, feature_node_sort_tracker, node_id, node_begin, node_end, node_size, feature_index); } @@ -42,7 +42,7 @@ void FeatureCutpointGrid::CalculateStridesNumeric(Eigen::MatrixXd& covariates, E void FeatureCutpointGrid::CalculateStridesOrderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { data_size_t node_size = node_end - node_begin; - + // Edge case 1: single observation double single_value; if (node_end - node_begin == 1) { @@ -63,7 +63,7 @@ void FeatureCutpointGrid::CalculateStridesOrderedCategorical(Eigen::MatrixXd& co cutpoint_values_.push_back(static_cast(single_value)); return; } - + // Run the "regular" algorithm for computing categorical strides data_size_t stride_begin = node_begin; data_size_t stride_length = 0; @@ -71,14 +71,14 @@ void FeatureCutpointGrid::CalculateStridesOrderedCategorical(Eigen::MatrixXd& co bool last_element; bool stride_complete; double current_val, next_val; - for (data_size_t i = node_begin; i < node_end; i++){ + for (data_size_t i = node_begin; i < node_end; i++) { current_sort_ind = feature_node_sort_tracker->SortIndex(i, feature_index); current_val = covariates(current_sort_ind, feature_index); last_element = ((i == node_end - 1)); // Increment stride length and bin_sum stride_length += 1; - + if (last_element) { // Update bin vectors node_stride_begin_.push_back(stride_begin); @@ -106,7 +106,7 @@ void 
FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& // TODO: refactor so that this initial code is shared between ordered and unordered categorical cutpoint calculation data_size_t node_size = node_end - node_begin; std::vector bin_sums; - + // Edge case 1: single observation double single_value; if (node_end - node_begin == 1) { @@ -127,7 +127,7 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& cutpoint_values_.push_back(static_cast(single_value)); return; } - + // Run the "regular" algorithm for computing categorical strides data_size_t stride_begin = node_begin; data_size_t stride_length = 0; @@ -137,11 +137,11 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& double current_val, next_val; double current_outcome, next_outcome; double bin_sum = 0; - for (data_size_t i = node_begin; i < node_end; i++){ + for (data_size_t i = node_begin; i < node_end; i++) { current_sort_ind = feature_node_sort_tracker->SortIndex(i, feature_index); current_val = covariates(current_sort_ind, feature_index); last_element = ((i == node_end - 1)); - + // Increment stride length and bin_sum stride_length += 1; bin_sum += residuals(current_sort_ind); @@ -156,7 +156,7 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& next_sort_ind = feature_node_sort_tracker->SortIndex(i + 1, feature_index); next_val = covariates(next_sort_ind, feature_index); stride_complete = (static_cast(next_val) != static_cast(current_val)); - + if (stride_complete) { // Update bin vectors node_stride_begin_.push_back(stride_begin); @@ -173,16 +173,16 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& } // Now re-arrange the categories according to the average outcome as in Fisher (1958) -// CHECK_EQ(residuals.cols(), 1); + // CHECK_EQ(residuals.cols(), 1); std::vector bin_avgs(bin_sums.size()); for (int i = 0; i < bin_sums.size(); i++) { bin_avgs[i] = bin_sums[i] / 
node_stride_length_[i]; } std::vector bin_sort_inds(bin_avgs.size()); std::iota(bin_sort_inds.begin(), bin_sort_inds.end(), 0); - auto comp_op = [&](size_t const &l, size_t const &r) { return std::less{}(bin_avgs[l], bin_avgs[r]); }; + auto comp_op = [&](size_t const& l, size_t const& r) { return std::less{}(bin_avgs[l], bin_avgs[r]); }; std::stable_sort(bin_sort_inds.begin(), bin_sort_inds.end(), comp_op); - + std::vector temp_stride_begin_; std::vector temp_stride_length_; std::vector temp_cutpoint_value_; @@ -192,9 +192,9 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& std::copy(cutpoint_values_.begin(), cutpoint_values_.end(), std::back_inserter(temp_cutpoint_value_)); for (int i = 0; i < node_stride_begin_.size(); i++) { - node_stride_begin_[i] = temp_stride_begin_[bin_sort_inds[i]]; - node_stride_length_[i] = temp_stride_length_[bin_sort_inds[i]]; - cutpoint_values_[i] = temp_cutpoint_value_[bin_sort_inds[i]]; + node_stride_begin_[i] = temp_stride_begin_[bin_sort_inds[i]]; + node_stride_length_[i] = temp_stride_length_[bin_sort_inds[i]]; + cutpoint_values_[i] = temp_cutpoint_value_[bin_sort_inds[i]]; } } @@ -218,7 +218,7 @@ void FeatureCutpointGrid::EnumerateNumericCutpointsDeduplication(Eigen::MatrixXd cutpoint_values_.push_back(first_val); return; } - + // Run the "regular" algorithm for computing categorical strides data_size_t stride_begin = node_begin; data_size_t stride_length = 0; @@ -226,14 +226,14 @@ void FeatureCutpointGrid::EnumerateNumericCutpointsDeduplication(Eigen::MatrixXd bool last_element; bool stride_complete; double current_val, next_val; - for (data_size_t i = node_begin; i < node_end; i++){ + for (data_size_t i = node_begin; i < node_end; i++) { current_sort_ind = feature_node_sort_tracker->SortIndex(i, feature_index); current_val = covariates(current_sort_ind, feature_index); last_element = ((i == node_end - 1)); // Increment stride length stride_length += 1; - + if (last_element) { // Update bin 
vectors node_stride_begin_.push_back(stride_begin); @@ -277,7 +277,7 @@ void FeatureCutpointGrid::ScanNumericCutpoints(Eigen::MatrixXd& covariates, Eige cutpoint_values_.push_back(first_val); return; } - + // Run the "regular" algorithm for computing categorical strides data_size_t stride_begin = node_begin; data_size_t stride_length = 0; @@ -287,14 +287,14 @@ void FeatureCutpointGrid::ScanNumericCutpoints(Eigen::MatrixXd& covariates, Eige bool bin_complete; double step_size = node_size / cutpoint_grid_size_; double current_val, next_val; - for (data_size_t i = node_begin; i < node_end; i++){ + for (data_size_t i = node_begin; i < node_end; i++) { current_sort_ind = feature_node_sort_tracker->SortIndex(i, feature_index); current_val = covariates(current_sort_ind, feature_index); last_element = ((i == node_end - 1)); // Increment stride length stride_length += 1; - + if (last_element) { // Update bin vectors node_stride_begin_.push_back(stride_begin); @@ -319,4 +319,4 @@ void FeatureCutpointGrid::ScanNumericCutpoints(Eigen::MatrixXd& covariates, Eige } } -} // namespace StochTree +} // namespace StochTree diff --git a/src/data.cpp b/src/data.cpp index e48e9255..66a2fa87 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -11,14 +11,14 @@ ColumnMatrix::ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, b ColumnMatrix::ColumnMatrix(std::string filename, std::string column_index_string, bool header, bool precise_float_parser) { // Convert string to vector of indices std::vector column_indices = Str2FeatureVec(column_index_string.c_str()); - + // Set up CSV parser data_size_t num_global_data = 0; auto parser = std::unique_ptr(Parser::CreateParser(filename.c_str(), header, 0, precise_float_parser)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename.c_str()); } - + // Determine number of columns in the data file int num_columns = parser->NumFeatures(); @@ -45,7 +45,7 @@ void ColumnMatrix::LoadData(double* data_ptr, data_size_t 
num_row, int num_col, double temp_value; for (data_size_t i = 0; i < num_row; ++i) { for (int j = 0; j < num_col; ++j) { - if (is_row_major){ + if (is_row_major) { // Numpy 2-d arrays are stored in "row major" order temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); } else { @@ -68,7 +68,7 @@ ColumnVector::ColumnVector(std::string filename, int32_t column_index, bool head if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename.c_str()); } - + // Read data to memory auto text_data = LoadTextDataToMemory(filename.c_str(), &num_global_data, header); int num_observations = static_cast(text_data.size()); @@ -131,7 +131,7 @@ void LoadData(double* data_ptr, int num_row, int num_col, bool is_row_major, Eig double temp_value; for (data_size_t i = 0; i < num_row; ++i) { for (int j = 0; j < num_col; ++j) { - if (is_row_major){ + if (is_row_major) { // Numpy 2-d arrays are stored in "row major" order temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); } else { @@ -154,4 +154,4 @@ void LoadData(double* data_ptr, int num_row, Eigen::VectorXd& data_vector) { } } -} // namespace StochTree +} // namespace StochTree diff --git a/src/forest.cpp b/src/forest.cpp index 26556016..357777e2 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -11,899 +11,914 @@ [[cpp11::register]] cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { - // Create smart pointer to newly allocated object - std::unique_ptr forest_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); - - // Release management of the pointer to R session - return cpp11::external_pointer(forest_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr forest_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); + + // Release management of the pointer to R session + return 
cpp11::external_pointer(forest_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer forest_container_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { - // Create smart pointer to newly allocated object - std::unique_ptr forest_sample_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); - - // Release management of the pointer to R session - return cpp11::external_pointer(forest_sample_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr forest_sample_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); + + // Release management of the pointer to R session + return cpp11::external_pointer(forest_sample_ptr_.release()); } [[cpp11::register]] cpp11::external_pointer forest_container_from_json_cpp(cpp11::external_pointer json_ptr, std::string forest_label) { - // Create smart pointer to newly allocated object - std::unique_ptr forest_sample_ptr_ = std::make_unique(0, 1, true); - - // Extract the forest's json - nlohmann::json forest_json = json_ptr->at("forests").at(forest_label); - - // Reset the forest sample container using the json - forest_sample_ptr_->Reset(); - forest_sample_ptr_->from_json(forest_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(forest_sample_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr forest_sample_ptr_ = std::make_unique(0, 1, true); + + // Extract the forest's json + nlohmann::json forest_json = json_ptr->at("forests").at(forest_label); + + // Reset the forest sample container using the json + forest_sample_ptr_->Reset(); + forest_sample_ptr_->from_json(forest_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(forest_sample_ptr_.release()); } [[cpp11::register]] void forest_container_append_from_json_cpp(cpp11::external_pointer 
forest_sample_ptr, cpp11::external_pointer json_ptr, std::string forest_label) { - // Extract the forest's json - nlohmann::json forest_json = json_ptr->at("forests").at(forest_label); - - // Append to the forest sample container using the json - forest_sample_ptr->append_from_json(forest_json); + // Extract the forest's json + nlohmann::json forest_json = json_ptr->at("forests").at(forest_label); + + // Append to the forest sample container using the json + forest_sample_ptr->append_from_json(forest_json); } [[cpp11::register]] cpp11::external_pointer forest_container_from_json_string_cpp(std::string json_string, std::string forest_label) { - // Create smart pointer to newly allocated object - std::unique_ptr forest_sample_ptr_ = std::make_unique(0, 1, true); - - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the forest's json - nlohmann::json forest_json = json_object.at("forests").at(forest_label); - - // Reset the forest sample container using the json - forest_sample_ptr_->Reset(); - forest_sample_ptr_->from_json(forest_json); - - // Release management of the pointer to R session - return cpp11::external_pointer(forest_sample_ptr_.release()); + // Create smart pointer to newly allocated object + std::unique_ptr forest_sample_ptr_ = std::make_unique(0, 1, true); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the forest's json + nlohmann::json forest_json = json_object.at("forests").at(forest_label); + + // Reset the forest sample container using the json + forest_sample_ptr_->Reset(); + forest_sample_ptr_->from_json(forest_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(forest_sample_ptr_.release()); } [[cpp11::register]] void forest_merge_cpp(cpp11::external_pointer inbound_forest_ptr, cpp11::external_pointer outbound_forest_ptr) { - 
inbound_forest_ptr->MergeForest(*outbound_forest_ptr); + inbound_forest_ptr->MergeForest(*outbound_forest_ptr); } [[cpp11::register]] void forest_add_constant_cpp(cpp11::external_pointer forest_ptr, double constant_value) { - forest_ptr->AddValueToLeaves(constant_value); + forest_ptr->AddValueToLeaves(constant_value); } [[cpp11::register]] void forest_multiply_constant_cpp(cpp11::external_pointer forest_ptr, double constant_multiple) { - forest_ptr->MultiplyLeavesByValue(constant_multiple); + forest_ptr->MultiplyLeavesByValue(constant_multiple); } [[cpp11::register]] void forest_container_append_from_json_string_cpp(cpp11::external_pointer forest_sample_ptr, std::string json_string, std::string forest_label) { - // Create a nlohmann::json object from the string - nlohmann::json json_object = nlohmann::json::parse(json_string); - - // Extract the forest's json - nlohmann::json forest_json = json_object.at("forests").at(forest_label); - - // Append to the forest sample container using the json - forest_sample_ptr->append_from_json(forest_json); + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the forest's json + nlohmann::json forest_json = json_object.at("forests").at(forest_label); + + // Append to the forest sample container using the json + forest_sample_ptr->append_from_json(forest_json); } [[cpp11::register]] void combine_forests_forest_container_cpp(cpp11::external_pointer forest_samples, cpp11::integers forest_inds) { - int num_forests = forest_inds.size(); - for (int j = 1; j < num_forests; j++) { - forest_samples->MergeForests(forest_inds[0], forest_inds[j]); - } - // double combined_forest_scale_factor = 1.0 / num_forests; - // forest_samples->MultiplyForest(forest_inds[0], combined_forest_scale_factor); + int num_forests = forest_inds.size(); + for (int j = 1; j < num_forests; j++) { + forest_samples->MergeForests(forest_inds[0], forest_inds[j]); + } + // double 
combined_forest_scale_factor = 1.0 / num_forests; + // forest_samples->MultiplyForest(forest_inds[0], combined_forest_scale_factor); } [[cpp11::register]] void add_to_forest_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_index, double constant_value) { - forest_samples->AddToForest(forest_index, constant_value); + forest_samples->AddToForest(forest_index, constant_value); } [[cpp11::register]] void multiply_forest_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_index, double constant_multiple) { - forest_samples->MultiplyForest(forest_index, constant_multiple); + forest_samples->MultiplyForest(forest_index, constant_multiple); } [[cpp11::register]] int num_samples_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->NumSamples(); + return forest_samples->NumSamples(); } [[cpp11::register]] int ensemble_tree_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num, int tree_num) { - return forest_samples->EnsembleTreeMaxDepth(ensemble_num, tree_num); + return forest_samples->EnsembleTreeMaxDepth(ensemble_num, tree_num); } [[cpp11::register]] double ensemble_average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num) { - return forest_samples->EnsembleAverageMaxDepth(ensemble_num); + return forest_samples->EnsembleAverageMaxDepth(ensemble_num); } [[cpp11::register]] double average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->AverageMaxDepth(); + return forest_samples->AverageMaxDepth(); } [[cpp11::register]] int num_leaves_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { - StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); - return forest->NumLeaves(); + StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); + return forest->NumLeaves(); } [[cpp11::register]] double 
sum_leaves_squared_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { - StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); - return forest->SumLeafSquared(); + StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); + return forest->SumLeafSquared(); } [[cpp11::register]] int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->NumTrees(); + return forest_samples->NumTrees(); } [[cpp11::register]] void json_save_forest_container_cpp(cpp11::external_pointer forest_samples, std::string json_filename) { - forest_samples->SaveToJsonFile(json_filename); + forest_samples->SaveToJsonFile(json_filename); } [[cpp11::register]] void json_load_forest_container_cpp(cpp11::external_pointer forest_samples, std::string json_filename) { - forest_samples->LoadFromJsonFile(json_filename); + forest_samples->LoadFromJsonFile(json_filename); } [[cpp11::register]] int leaf_dimension_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->OutputDimension(); + return forest_samples->OutputDimension(); } [[cpp11::register]] int is_leaf_constant_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->IsLeafConstant(); + return forest_samples->IsLeafConstant(); } [[cpp11::register]] int is_exponentiated_forest_container_cpp(cpp11::external_pointer forest_samples) { - return forest_samples->IsExponentiated(); + return forest_samples->IsExponentiated(); } [[cpp11::register]] bool all_roots_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { - return forest_samples->AllRoots(forest_num); + return forest_samples->AllRoots(forest_num); } [[cpp11::register]] void add_sample_forest_container_cpp(cpp11::external_pointer forest_samples) { - forest_samples->AddSamples(1); + forest_samples->AddSamples(1); } [[cpp11::register]] void set_leaf_value_forest_container_cpp(cpp11::external_pointer 
forest_samples, double leaf_value) { - forest_samples->InitializeRoot(leaf_value); + forest_samples->InitializeRoot(leaf_value); } [[cpp11::register]] void add_sample_value_forest_container_cpp(cpp11::external_pointer forest_samples, double leaf_value) { - if (forest_samples->OutputDimension() != 1) { - cpp11::stop("leaf_value must match forest leaf dimension"); - } - int num_samples = forest_samples->NumSamples(); - forest_samples->AddSamples(1); - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(num_samples); - int num_trees = ensemble->NumTrees(); - for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = ensemble->GetTree(i); - tree->SetLeaf(0, leaf_value); - } + if (forest_samples->OutputDimension() != 1) { + cpp11::stop("leaf_value must match forest leaf dimension"); + } + int num_samples = forest_samples->NumSamples(); + forest_samples->AddSamples(1); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(num_samples); + int num_trees = ensemble->NumTrees(); + for (int i = 0; i < num_trees; i++) { + StochTree::Tree* tree = ensemble->GetTree(i); + tree->SetLeaf(0, leaf_value); + } } [[cpp11::register]] void add_sample_vector_forest_container_cpp(cpp11::external_pointer forest_samples, cpp11::doubles leaf_vector) { - if (forest_samples->OutputDimension() != leaf_vector.size()) { - cpp11::stop("leaf_vector must match forest leaf dimension"); - } - int num_samples = forest_samples->NumSamples(); - forest_samples->AddSamples(1); - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(num_samples); - int num_trees = ensemble->NumTrees(); - std::vector leaf_vector_cast(leaf_vector.begin(), leaf_vector.end()); - for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = ensemble->GetTree(i); - tree->SetLeafVector(0, leaf_vector_cast); - } + if (forest_samples->OutputDimension() != leaf_vector.size()) { + cpp11::stop("leaf_vector must match forest leaf dimension"); + } + int num_samples = forest_samples->NumSamples(); + 
forest_samples->AddSamples(1); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(num_samples); + int num_trees = ensemble->NumTrees(); + std::vector leaf_vector_cast(leaf_vector.begin(), leaf_vector.end()); + for (int i = 0; i < num_trees; i++) { + StochTree::Tree* tree = ensemble->GetTree(i); + tree->SetLeafVector(0, leaf_vector_cast); + } } [[cpp11::register]] void add_numeric_split_tree_value_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int leaf_num, int feature_num, double split_threshold, double left_leaf_value, double right_leaf_value) { - if (forest_samples->OutputDimension() != 1) { - cpp11::stop("leaf_vector must match forest leaf dimension"); - } - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - if (!tree->IsLeaf(leaf_num)) { - cpp11::stop("leaf_num is not a leaf"); - } - tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value); + if (forest_samples->OutputDimension() != 1) { + cpp11::stop("leaf_vector must match forest leaf dimension"); + } + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + if (!tree->IsLeaf(leaf_num)) { + cpp11::stop("leaf_num is not a leaf"); + } + tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value); } [[cpp11::register]] void add_numeric_split_tree_vector_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int leaf_num, int feature_num, double split_threshold, cpp11::doubles left_leaf_vector, cpp11::doubles right_leaf_vector) { - if (forest_samples->OutputDimension() != left_leaf_vector.size()) { - cpp11::stop("left_leaf_vector must match forest leaf dimension"); - } - if (forest_samples->OutputDimension() != right_leaf_vector.size()) { - cpp11::stop("right_leaf_vector must match forest leaf 
dimension"); - } - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - std::vector left_leaf_vector_cast(left_leaf_vector.begin(), left_leaf_vector.end()); - std::vector right_leaf_vector_cast(right_leaf_vector.begin(), right_leaf_vector.end()); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - if (!tree->IsLeaf(leaf_num)) { - cpp11::stop("leaf_num is not a leaf"); - } - tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast); + if (forest_samples->OutputDimension() != left_leaf_vector.size()) { + cpp11::stop("left_leaf_vector must match forest leaf dimension"); + } + if (forest_samples->OutputDimension() != right_leaf_vector.size()) { + cpp11::stop("right_leaf_vector must match forest leaf dimension"); + } + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + std::vector left_leaf_vector_cast(left_leaf_vector.begin(), left_leaf_vector.end()); + std::vector right_leaf_vector_cast(right_leaf_vector.begin(), right_leaf_vector.end()); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + if (!tree->IsLeaf(leaf_num)) { + cpp11::stop("leaf_num is not a leaf"); + } + tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast); } [[cpp11::register]] cpp11::writable::integers get_tree_leaves_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - std::vector leaves_raw = tree->GetLeaves(); - cpp11::writable::integers leaves(leaves_raw.begin(), leaves_raw.end()); - return leaves; + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + std::vector leaves_raw = tree->GetLeaves(); + cpp11::writable::integers leaves(leaves_raw.begin(), leaves_raw.end()); + return leaves; } 
[[cpp11::register]] cpp11::writable::integers get_tree_split_counts_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int num_features) { - cpp11::writable::integers output(num_features); - for (int i = 0; i < output.size(); i++) output.at(i) = 0; - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - std::vector split_nodes = tree->GetInternalNodes(); - for (int i = 0; i < split_nodes.size(); i++) { - auto node_id = split_nodes.at(i); - auto split_feature = tree->SplitIndex(node_id); - output.at(split_feature)++; - } - return output; + cpp11::writable::integers output(num_features); + for (int i = 0; i < output.size(); i++) output.at(i) = 0; + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + std::vector split_nodes = tree->GetInternalNodes(); + for (int i = 0; i < split_nodes.size(); i++) { + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); + output.at(split_feature)++; + } + return output; } [[cpp11::register]] cpp11::writable::integers get_forest_split_counts_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int num_features) { - cpp11::writable::integers output(num_features); - for (int i = 0; i < output.size(); i++) output.at(i) = 0; - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - int num_trees = ensemble->NumTrees(); - for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = ensemble->GetTree(i); - std::vector split_nodes = tree->GetInternalNodes(); - for (int j = 0; j < split_nodes.size(); j++) { - auto node_id = split_nodes.at(j); - auto split_feature = tree->SplitIndex(node_id); - output.at(split_feature)++; - } + cpp11::writable::integers output(num_features); + for (int i = 0; i < output.size(); i++) output.at(i) = 0; + StochTree::TreeEnsemble* 
ensemble = forest_samples->GetEnsemble(forest_num); + int num_trees = ensemble->NumTrees(); + for (int i = 0; i < num_trees; i++) { + StochTree::Tree* tree = ensemble->GetTree(i); + std::vector split_nodes = tree->GetInternalNodes(); + for (int j = 0; j < split_nodes.size(); j++) { + auto node_id = split_nodes.at(j); + auto split_feature = tree->SplitIndex(node_id); + output.at(split_feature)++; } - return output; + } + return output; } [[cpp11::register]] cpp11::writable::integers get_overall_split_counts_forest_container_cpp(cpp11::external_pointer forest_samples, int num_features) { - cpp11::writable::integers output(num_features); - for (int i = 0; i < output.size(); i++) output.at(i) = 0; - int num_samples = forest_samples->NumSamples(); - int num_trees = forest_samples->NumTrees(); - for (int i = 0; i < num_samples; i++) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(i); - for (int j = 0; j < num_trees; j++) { - StochTree::Tree* tree = ensemble->GetTree(j); - std::vector split_nodes = tree->GetInternalNodes(); - for (int k = 0; k < split_nodes.size(); k++) { - auto node_id = split_nodes.at(k); - auto split_feature = tree->SplitIndex(node_id); - output.at(split_feature)++; - } - } + cpp11::writable::integers output(num_features); + for (int i = 0; i < output.size(); i++) output.at(i) = 0; + int num_samples = forest_samples->NumSamples(); + int num_trees = forest_samples->NumTrees(); + for (int i = 0; i < num_samples; i++) { + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(i); + for (int j = 0; j < num_trees; j++) { + StochTree::Tree* tree = ensemble->GetTree(j); + std::vector split_nodes = tree->GetInternalNodes(); + for (int k = 0; k < split_nodes.size(); k++) { + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); + output.at(split_feature)++; + } } - return output; + } + return output; } [[cpp11::register]] cpp11::writable::integers 
get_granular_split_count_array_forest_container_cpp(cpp11::external_pointer forest_samples, int num_features) { - int num_samples = forest_samples->NumSamples(); - int num_trees = forest_samples->NumTrees(); - cpp11::writable::integers output(num_features*num_samples*num_trees); - for (int elem = 0; elem < output.size(); elem++) output.at(elem) = 0; - for (int i = 0; i < num_samples; i++) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(i); - for (int j = 0; j < num_trees; j++) { - StochTree::Tree* tree = ensemble->GetTree(j); - std::vector split_nodes = tree->GetInternalNodes(); - for (int k = 0; k < split_nodes.size(); k++) { - auto node_id = split_nodes.at(k); - auto split_feature = tree->SplitIndex(node_id); - output.at(split_feature*num_samples*num_trees + j*num_samples + i)++; - } - } + int num_samples = forest_samples->NumSamples(); + int num_trees = forest_samples->NumTrees(); + cpp11::writable::integers output(num_features * num_samples * num_trees); + for (int elem = 0; elem < output.size(); elem++) output.at(elem) = 0; + for (int i = 0; i < num_samples; i++) { + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(i); + for (int j = 0; j < num_trees; j++) { + StochTree::Tree* tree = ensemble->GetTree(j); + std::vector split_nodes = tree->GetInternalNodes(); + for (int k = 0; k < split_nodes.size(); k++) { + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); + output.at(split_feature * num_samples * num_trees + j * num_samples + i)++; + } } - return output; + } + return output; } [[cpp11::register]] void set_leaf_vector_forest_container_cpp(cpp11::external_pointer forest_samples, cpp11::doubles leaf_vector) { - std::vector leaf_vector_converted(leaf_vector.size()); - for (int i = 0; i < leaf_vector.size(); i++) { - leaf_vector_converted[i] = leaf_vector[i]; - } - forest_samples->InitializeRoot(leaf_vector_converted); + std::vector leaf_vector_converted(leaf_vector.size()); + for (int i = 0; i 
< leaf_vector.size(); i++) { + leaf_vector_converted[i] = leaf_vector[i]; + } + forest_samples->InitializeRoot(leaf_vector_converted); } [[cpp11::register]] bool is_leaf_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->IsLeaf(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->IsLeaf(node_id); } [[cpp11::register]] bool is_numeric_split_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->IsNumericSplitNode(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->IsNumericSplitNode(node_id); } [[cpp11::register]] bool is_categorical_split_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->IsCategoricalSplitNode(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->IsCategoricalSplitNode(node_id); } [[cpp11::register]] int parent_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->Parent(node_id); + StochTree::TreeEnsemble* ensemble = 
forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->Parent(node_id); } [[cpp11::register]] int left_child_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->LeftChild(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->LeftChild(node_id); } [[cpp11::register]] int right_child_node_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->RightChild(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->RightChild(node_id); } [[cpp11::register]] int node_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->GetDepth(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->GetDepth(node_id); } [[cpp11::register]] int split_index_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->SplitIndex(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree 
= ensemble->GetTree(tree_num); + return tree->SplitIndex(node_id); } [[cpp11::register]] double split_theshold_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->Threshold(node_id); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->Threshold(node_id); } [[cpp11::register]] cpp11::writable::integers split_categories_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - std::vector raw_categories = tree->CategoryList(node_id); - cpp11::writable::integers output(raw_categories.begin(), raw_categories.end()); - return output; + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + std::vector raw_categories = tree->CategoryList(node_id); + cpp11::writable::integers output(raw_categories.begin(), raw_categories.end()); + return output; } [[cpp11::register]] cpp11::writable::doubles leaf_values_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num, int node_id) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - int num_outputs = tree->OutputDimension(); - cpp11::writable::doubles output(num_outputs); - for (int i = 0; i < num_outputs; i++) { - output[i] = tree->LeafValue(node_id, i); - } - return output; + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + int num_outputs = tree->OutputDimension(); + 
cpp11::writable::doubles output(num_outputs); + for (int i = 0; i < num_outputs; i++) { + output[i] = tree->LeafValue(node_id, i); + } + return output; } [[cpp11::register]] int num_nodes_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->NumValidNodes(); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->NumValidNodes(); } [[cpp11::register]] int num_leaves_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->NumLeaves(); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->NumLeaves(); } [[cpp11::register]] int num_leaf_parents_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->NumLeafParents(); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return tree->NumLeafParents(); } [[cpp11::register]] int num_split_nodes_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - return tree->NumSplitNodes(); + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + return 
tree->NumSplitNodes(); } [[cpp11::register]] cpp11::writable::integers nodes_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - std::vector leaves = tree->GetNodes(); - cpp11::writable::integers output(leaves.begin(), leaves.end()); - return output; + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + std::vector leaves = tree->GetNodes(); + cpp11::writable::integers output(leaves.begin(), leaves.end()); + return output; } [[cpp11::register]] cpp11::writable::integers leaves_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num, int tree_num) { - StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); - StochTree::Tree* tree = ensemble->GetTree(tree_num); - std::vector leaves = tree->GetLeaves(); - cpp11::writable::integers output(leaves.begin(), leaves.end()); - return output; -} - -[[cpp11::register]] -void initialize_forest_model_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer tracker, - cpp11::doubles init_values, int leaf_model_int){ - // Convert leaf model type to enum - StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; - else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; - else StochTree::Log::Fatal("Invalid model type"); - - // Unpack initial value - int num_trees = forest_samples->NumTrees(); - double init_val; - std::vector init_value_vector; - if ((model_type == 
StochTree::ModelType::kConstantLeafGaussian) || - (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || - (model_type == StochTree::ModelType::kLogLinearVariance)) { - init_val = init_values.at(0); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int leaf_dim = init_values.size(); - init_value_vector.resize(leaf_dim); - for (int i = 0; i < leaf_dim; i++) { - init_value_vector[i] = init_values[i] / static_cast(num_trees); - } - } - - // Initialize the models accordingly - if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - forest_samples->InitializeRoot(init_val / static_cast(num_trees)); - UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(0), false, std::minus()); - tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); - } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - forest_samples->InitializeRoot(init_val / static_cast(num_trees)); - UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(0), true, std::minus()); - tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - forest_samples->InitializeRoot(init_value_vector); - UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(0), true, std::minus()); - tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); - } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - forest_samples->InitializeRoot(std::log(init_val) / static_cast(num_trees)); - tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); - int n = data->NumObservations(); - std::vector initial_preds(n, init_val); - data->AddVarianceWeights(initial_preds.data(), n); - } -} - -[[cpp11::register]] -void adjust_residual_forest_container_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - 
cpp11::external_pointer forest_samples, - cpp11::external_pointer tracker, + StochTree::TreeEnsemble* ensemble = forest_samples->GetEnsemble(forest_num); + StochTree::Tree* tree = ensemble->GetTree(tree_num); + std::vector leaves = tree->GetLeaves(); + cpp11::writable::integers output(leaves.begin(), leaves.end()); + return output; +} + +[[cpp11::register]] +void initialize_forest_model_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer tracker, + cpp11::doubles init_values, int leaf_model_int) { + // Convert leaf model type to enum + StochTree::ModelType model_type; + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) + model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else + StochTree::Log::Fatal("Invalid model type"); + + // Unpack initial value + int num_trees = forest_samples->NumTrees(); + double init_val; + std::vector init_value_vector; + if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || + (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || + (model_type == StochTree::ModelType::kLogLinearVariance)) { + init_val = init_values.at(0); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + int leaf_dim = init_values.size(); + init_value_vector.resize(leaf_dim); + for (int i = 0; i < leaf_dim; i++) { + init_value_vector[i] = init_values[i] / static_cast(num_trees); + } + } + + // Initialize the models accordingly + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + forest_samples->InitializeRoot(init_val / static_cast(num_trees)); + UpdateResidualEntireForest(*tracker, *data, *residual, 
forest_samples->GetEnsemble(0), false, std::minus()); + tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + forest_samples->InitializeRoot(init_val / static_cast(num_trees)); + UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(0), true, std::minus()); + tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + forest_samples->InitializeRoot(init_value_vector); + UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(0), true, std::minus()); + tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); + } else if (model_type == StochTree::ModelType::kLogLinearVariance) { + forest_samples->InitializeRoot(std::log(init_val) / static_cast(num_trees)); + tracker->UpdatePredictions(forest_samples->GetEnsemble(0), *data); + int n = data->NumObservations(); + std::vector initial_preds(n, init_val); + data->AddVarianceWeights(initial_preds.data(), n); + } +} + +[[cpp11::register]] +void adjust_residual_forest_container_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer tracker, bool requires_basis, int forest_num, bool add) { - // Determine whether or not we are adding forest_num to the residuals - std::function op; - if (add) op = std::plus(); - else op = std::minus(); - - // Perform the update (addition / subtraction) operation - StochTree::UpdateResidualEntireForest(*tracker, *data, *residual, forest_samples->GetEnsemble(forest_num), requires_basis, op); + // Determine whether or not we are adding forest_num to the residuals + std::function op; + if (add) + op = std::plus(); + else + op = std::minus(); + + // Perform the update (addition / subtraction) operation + StochTree::UpdateResidualEntireForest(*tracker, *data, 
*residual, forest_samples->GetEnsemble(forest_num), requires_basis, op); } [[cpp11::register]] -void propagate_basis_update_forest_container_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer tracker, +void propagate_basis_update_forest_container_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer tracker, int forest_num) { - // Perform the update (addition / subtraction) operation - StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, forest_samples->GetEnsemble(forest_num)); + // Perform the update (addition / subtraction) operation + StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, forest_samples->GetEnsemble(forest_num)); } [[cpp11::register]] -void remove_sample_forest_container_cpp(cpp11::external_pointer forest_samples, +void remove_sample_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { - forest_samples->DeleteSample(forest_num); + forest_samples->DeleteSample(forest_num); } [[cpp11::register]] cpp11::writable::doubles_matrix<> predict_forest_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset) { - // Predict from the sampled forests - std::vector output_raw = forest_samples->Predict(*dataset); - - // Convert result to a matrix - int n = dataset->GetCovariates().rows(); - int num_samples = forest_samples->NumSamples(); - cpp11::writable::doubles_matrix<> output(n, num_samples); - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < num_samples; j++) { - output(i, j) = output_raw[n*j + i]; - } + // Predict from the sampled forests + std::vector output_raw = forest_samples->Predict(*dataset); + + // Convert result to a matrix + int n = dataset->GetCovariates().rows(); + int num_samples = forest_samples->NumSamples(); + cpp11::writable::doubles_matrix<> output(n, num_samples); + for (size_t i = 0; i < n; i++) { + 
for (int j = 0; j < num_samples; j++) { + output(i, j) = output_raw[n * j + i]; } - - return output; + } + + return output; } [[cpp11::register]] cpp11::writable::doubles predict_forest_raw_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset) { - // Predict from the sampled forests - std::vector output_raw = forest_samples->PredictRaw(*dataset); - - // Unpack / re-arrange results - int n = dataset->GetCovariates().rows(); - int num_samples = forest_samples->NumSamples(); - int output_dimension = forest_samples->OutputDimension(); - cpp11::writable::doubles output(n*output_dimension*num_samples); - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < output_dimension; j++) { - for (int k = 0; k < num_samples; k++) { - // Convert from idiosyncratic C++ storage to "column-major" --- first dimension is data row, second is output column, third is sample number - output.at(k*output_dimension*n + j*n + i) = output_raw[k*output_dimension*n + i*output_dimension + j]; - } - } + // Predict from the sampled forests + std::vector output_raw = forest_samples->PredictRaw(*dataset); + + // Unpack / re-arrange results + int n = dataset->GetCovariates().rows(); + int num_samples = forest_samples->NumSamples(); + int output_dimension = forest_samples->OutputDimension(); + cpp11::writable::doubles output(n * output_dimension * num_samples); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < output_dimension; j++) { + for (int k = 0; k < num_samples; k++) { + // Convert from idiosyncratic C++ storage to "column-major" --- first dimension is data row, second is output column, third is sample number + output.at(k * output_dimension * n + j * n + i) = output_raw[k * output_dimension * n + i * output_dimension + j]; + } } - - return output; + } + + return output; } [[cpp11::register]] cpp11::writable::doubles_matrix<> predict_forest_raw_single_forest_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset, int forest_num) { - // Predict 
from the sampled forests - std::vector output_raw = forest_samples->PredictRaw(*dataset, forest_num); - - // Convert result to a matrix - int n = dataset->GetCovariates().rows(); - int output_dimension = forest_samples->OutputDimension(); - cpp11::writable::doubles_matrix<> output(n, output_dimension); - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < output_dimension; j++) { - output(i, j) = output_raw[i*output_dimension + j]; - } + // Predict from the sampled forests + std::vector output_raw = forest_samples->PredictRaw(*dataset, forest_num); + + // Convert result to a matrix + int n = dataset->GetCovariates().rows(); + int output_dimension = forest_samples->OutputDimension(); + cpp11::writable::doubles_matrix<> output(n, output_dimension); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < output_dimension; j++) { + output(i, j) = output_raw[i * output_dimension + j]; } - - return output; + } + + return output; } [[cpp11::register]] cpp11::writable::doubles_matrix<> predict_forest_raw_single_tree_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset, int forest_num, int tree_num) { - // Predict from the sampled forests - std::vector output_raw = forest_samples->PredictRawSingleTree(*dataset, forest_num, tree_num); - - // Convert result to a matrix - int n = dataset->GetCovariates().rows(); - int output_dimension = forest_samples->OutputDimension(); - cpp11::writable::doubles_matrix<> output(n, output_dimension); - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < output_dimension; j++) { - output(i, j) = output_raw[i*output_dimension + j]; - } + // Predict from the sampled forests + std::vector output_raw = forest_samples->PredictRawSingleTree(*dataset, forest_num, tree_num); + + // Convert result to a matrix + int n = dataset->GetCovariates().rows(); + int output_dimension = forest_samples->OutputDimension(); + cpp11::writable::doubles_matrix<> output(n, output_dimension); + for (size_t i = 0; i < n; i++) { + for (int j = 
0; j < output_dimension; j++) { + output(i, j) = output_raw[i * output_dimension + j]; } - return output; + } + return output; } [[cpp11::register]] cpp11::writable::doubles predict_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset) { - int n = dataset->GetCovariates().rows(); - std::vector output(n); - active_forest->PredictInplace(*dataset, output, 0); - return output; + int n = dataset->GetCovariates().rows(); + std::vector output(n); + active_forest->PredictInplace(*dataset, output, 0); + return output; } [[cpp11::register]] cpp11::writable::doubles predict_raw_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset) { - int n = dataset->GetCovariates().rows(); - int output_dimension = active_forest->OutputDimension(); - std::vector output_raw(n*output_dimension); - active_forest->PredictRawInplace(*dataset, output_raw, 0); - - cpp11::writable::doubles output(n*output_dimension); - for (size_t i = 0; i < n; i++) { - for (int j = 0; j < output_dimension; j++) { - // Convert from row-major to column-major - output.at(j*n + i) = output_raw[i*output_dimension + j]; - } + int n = dataset->GetCovariates().rows(); + int output_dimension = active_forest->OutputDimension(); + std::vector output_raw(n * output_dimension); + active_forest->PredictRawInplace(*dataset, output_raw, 0); + + cpp11::writable::doubles output(n * output_dimension); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < output_dimension; j++) { + // Convert from row-major to column-major + output.at(j * n + i) = output_raw[i * output_dimension + j]; } - - return output; + } + + return output; } [[cpp11::register]] int leaf_dimension_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->OutputDimension(); + return active_forest->OutputDimension(); } [[cpp11::register]] double average_max_depth_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->AverageMaxDepth(); + 
return active_forest->AverageMaxDepth(); } [[cpp11::register]] int num_trees_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->NumTrees(); + return active_forest->NumTrees(); } [[cpp11::register]] int ensemble_tree_max_depth_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num) { - return active_forest->TreeMaxDepth(tree_num); + return active_forest->TreeMaxDepth(tree_num); } [[cpp11::register]] int is_leaf_constant_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->IsLeafConstant(); + return active_forest->IsLeafConstant(); } [[cpp11::register]] int is_exponentiated_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->IsExponentiated(); + return active_forest->IsExponentiated(); } [[cpp11::register]] bool all_roots_active_forest_cpp(cpp11::external_pointer active_forest) { - return active_forest->AllRoots(); + return active_forest->AllRoots(); } [[cpp11::register]] void set_leaf_value_active_forest_cpp(cpp11::external_pointer active_forest, double leaf_value) { - active_forest->SetLeafValue(leaf_value); + active_forest->SetLeafValue(leaf_value); } [[cpp11::register]] void set_leaf_vector_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::doubles leaf_vector) { - std::vector leaf_vector_cast(leaf_vector.begin(), leaf_vector.end()); - active_forest->SetLeafVector(leaf_vector_cast); + std::vector leaf_vector_cast(leaf_vector.begin(), leaf_vector.end()); + active_forest->SetLeafVector(leaf_vector_cast); } [[cpp11::register]] void add_numeric_split_tree_value_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, double left_leaf_value, double right_leaf_value) { - if (active_forest->OutputDimension() != 1) { - cpp11::stop("leaf_vector must match forest leaf dimension"); - } - StochTree::Tree* tree = active_forest->GetTree(tree_num); - if (!tree->IsLeaf(leaf_num)) { - 
cpp11::stop("leaf_num is not a leaf"); - } - tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value); + if (active_forest->OutputDimension() != 1) { + cpp11::stop("leaf_vector must match forest leaf dimension"); + } + StochTree::Tree* tree = active_forest->GetTree(tree_num); + if (!tree->IsLeaf(leaf_num)) { + cpp11::stop("leaf_num is not a leaf"); + } + tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value); } [[cpp11::register]] void add_numeric_split_tree_vector_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, cpp11::doubles left_leaf_vector, cpp11::doubles right_leaf_vector) { - if (active_forest->OutputDimension() != left_leaf_vector.size()) { - cpp11::stop("left_leaf_vector must match forest leaf dimension"); - } - if (active_forest->OutputDimension() != right_leaf_vector.size()) { - cpp11::stop("right_leaf_vector must match forest leaf dimension"); - } - std::vector left_leaf_vector_cast(left_leaf_vector.begin(), left_leaf_vector.end()); - std::vector right_leaf_vector_cast(right_leaf_vector.begin(), right_leaf_vector.end()); - StochTree::Tree* tree = active_forest->GetTree(tree_num); - if (!tree->IsLeaf(leaf_num)) { - cpp11::stop("leaf_num is not a leaf"); - } - tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast); + if (active_forest->OutputDimension() != left_leaf_vector.size()) { + cpp11::stop("left_leaf_vector must match forest leaf dimension"); + } + if (active_forest->OutputDimension() != right_leaf_vector.size()) { + cpp11::stop("right_leaf_vector must match forest leaf dimension"); + } + std::vector left_leaf_vector_cast(left_leaf_vector.begin(), left_leaf_vector.end()); + std::vector right_leaf_vector_cast(right_leaf_vector.begin(), right_leaf_vector.end()); + StochTree::Tree* tree = active_forest->GetTree(tree_num); + if (!tree->IsLeaf(leaf_num)) { + 
cpp11::stop("leaf_num is not a leaf"); + } + tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast); } [[cpp11::register]] cpp11::writable::integers get_tree_leaves_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num) { - StochTree::Tree* tree = active_forest->GetTree(tree_num); - std::vector leaves_raw = tree->GetLeaves(); - cpp11::writable::integers leaves(leaves_raw.begin(), leaves_raw.end()); - return leaves; + StochTree::Tree* tree = active_forest->GetTree(tree_num); + std::vector leaves_raw = tree->GetLeaves(); + cpp11::writable::integers leaves(leaves_raw.begin(), leaves_raw.end()); + return leaves; } [[cpp11::register]] cpp11::writable::integers get_tree_split_counts_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int num_features) { - cpp11::writable::integers output(num_features); - for (int i = 0; i < output.size(); i++) output.at(i) = 0; - StochTree::Tree* tree = active_forest->GetTree(tree_num); - std::vector split_nodes = tree->GetInternalNodes(); - for (int i = 0; i < split_nodes.size(); i++) { - auto node_id = split_nodes.at(i); - auto feature_split = tree->SplitIndex(node_id); - output.at(feature_split)++; - } - return output; + cpp11::writable::integers output(num_features); + for (int i = 0; i < output.size(); i++) output.at(i) = 0; + StochTree::Tree* tree = active_forest->GetTree(tree_num); + std::vector split_nodes = tree->GetInternalNodes(); + for (int i = 0; i < split_nodes.size(); i++) { + auto node_id = split_nodes.at(i); + auto feature_split = tree->SplitIndex(node_id); + output.at(feature_split)++; + } + return output; } [[cpp11::register]] cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::external_pointer active_forest, int num_features) { - cpp11::writable::integers output(num_features); - for (int i = 0; i < output.size(); i++) output.at(i) = 0; - int num_trees = active_forest->NumTrees(); - for (int i = 0; i < num_trees; 
i++) { - StochTree::Tree* tree = active_forest->GetTree(i); - std::vector split_nodes = tree->GetInternalNodes(); - for (int j = 0; j < split_nodes.size(); j++) { - auto node_id = split_nodes.at(j); - auto feature_split = tree->SplitIndex(node_id); - output.at(feature_split)++; - } + cpp11::writable::integers output(num_features); + for (int i = 0; i < output.size(); i++) output.at(i) = 0; + int num_trees = active_forest->NumTrees(); + for (int i = 0; i < num_trees; i++) { + StochTree::Tree* tree = active_forest->GetTree(i); + std::vector split_nodes = tree->GetInternalNodes(); + for (int j = 0; j < split_nodes.size(); j++) { + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); + output.at(feature_split)++; } - return output; + } + return output; } [[cpp11::register]] cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11::external_pointer active_forest, int num_features) { - int num_trees = active_forest->NumTrees(); - cpp11::writable::integers output(num_features*num_trees); - for (int elem = 0; elem < output.size(); elem++) output.at(elem) = 0; - for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = active_forest->GetTree(i); - std::vector split_nodes = tree->GetInternalNodes(); - for (int j = 0; j < split_nodes.size(); j++) { - auto node_id = split_nodes.at(j); - auto feature_split = tree->SplitIndex(node_id); - output.at(feature_split*num_trees + i)++; - } - } - return output; -} - -[[cpp11::register]] -void initialize_forest_model_active_forest_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::doubles init_values, int leaf_model_int){ - // Convert leaf model type to enum - StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; 
- else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; - else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; - else StochTree::Log::Fatal("Invalid model type"); - - // Unpack initial value - int num_trees = active_forest->NumTrees(); - double init_val; - std::vector init_value_vector; - if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || - (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || - (model_type == StochTree::ModelType::kLogLinearVariance)) { - init_val = init_values.at(0); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int leaf_dim = init_values.size(); - init_value_vector.resize(leaf_dim); - for (int i = 0; i < leaf_dim; i++) { - init_value_vector[i] = init_values[i] / static_cast(num_trees); - } - } - - // Initialize the models accordingly - double leaf_init_val; - if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - leaf_init_val = init_val / static_cast(num_trees); - active_forest->SetLeafValue(leaf_init_val); - UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus()); - tracker->UpdatePredictions(active_forest.get(), *data); - } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - leaf_init_val = init_val / static_cast(num_trees); - active_forest->SetLeafValue(leaf_init_val); - UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus()); - tracker->UpdatePredictions(active_forest.get(), *data); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - active_forest->SetLeafVector(init_value_vector); - UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus()); - tracker->UpdatePredictions(active_forest.get(), *data); - } 
else if (model_type == StochTree::ModelType::kLogLinearVariance) { - leaf_init_val = std::log(init_val) / static_cast(num_trees); - active_forest->SetLeafValue(leaf_init_val); - tracker->UpdatePredictions(active_forest.get(), *data); - int n = data->NumObservations(); - std::vector initial_preds(n, init_val); - data->AddVarianceWeights(initial_preds.data(), n); - } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - leaf_init_val = init_val / static_cast(num_trees); - active_forest->SetLeafValue(leaf_init_val); - UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus()); - tracker->UpdatePredictions(active_forest.get(), *data); - } -} - -[[cpp11::register]] -void adjust_residual_active_forest_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, + int num_trees = active_forest->NumTrees(); + cpp11::writable::integers output(num_features * num_trees); + for (int elem = 0; elem < output.size(); elem++) output.at(elem) = 0; + for (int i = 0; i < num_trees; i++) { + StochTree::Tree* tree = active_forest->GetTree(i); + std::vector split_nodes = tree->GetInternalNodes(); + for (int j = 0; j < split_nodes.size(); j++) { + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); + output.at(feature_split * num_trees + i)++; + } + } + return output; +} + +[[cpp11::register]] +void initialize_forest_model_active_forest_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::doubles init_values, int leaf_model_int) { + // Convert leaf model type to enum + StochTree::ModelType model_type; + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) 
+ model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) + model_type = StochTree::ModelType::kCloglogOrdinal; + else + StochTree::Log::Fatal("Invalid model type"); + + // Unpack initial value + int num_trees = active_forest->NumTrees(); + double init_val; + std::vector init_value_vector; + if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || + (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || + (model_type == StochTree::ModelType::kLogLinearVariance)) { + init_val = init_values.at(0); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + int leaf_dim = init_values.size(); + init_value_vector.resize(leaf_dim); + for (int i = 0; i < leaf_dim; i++) { + init_value_vector[i] = init_values[i] / static_cast(num_trees); + } + } + + // Initialize the models accordingly + double leaf_init_val; + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + leaf_init_val = init_val / static_cast(num_trees); + active_forest->SetLeafValue(leaf_init_val); + UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus()); + tracker->UpdatePredictions(active_forest.get(), *data); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + leaf_init_val = init_val / static_cast(num_trees); + active_forest->SetLeafValue(leaf_init_val); + UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus()); + tracker->UpdatePredictions(active_forest.get(), *data); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + active_forest->SetLeafVector(init_value_vector); + UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus()); + tracker->UpdatePredictions(active_forest.get(), *data); + } else if (model_type == 
StochTree::ModelType::kLogLinearVariance) { + leaf_init_val = std::log(init_val) / static_cast(num_trees); + active_forest->SetLeafValue(leaf_init_val); + tracker->UpdatePredictions(active_forest.get(), *data); + int n = data->NumObservations(); + std::vector initial_preds(n, init_val); + data->AddVarianceWeights(initial_preds.data(), n); + } else if (model_type == StochTree::ModelType::kLogLinearVariance) { + leaf_init_val = init_val / static_cast(num_trees); + active_forest->SetLeafValue(leaf_init_val); + UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus()); + tracker->UpdatePredictions(active_forest.get(), *data); + } +} + +[[cpp11::register]] +void adjust_residual_active_forest_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, bool requires_basis, bool add) { - // Determine whether or not we are adding forest predictions to the residuals - std::function op; - if (add) op = std::plus(); - else op = std::minus(); - - // Perform the update (addition / subtraction) operation - StochTree::UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), requires_basis, op); + // Determine whether or not we are adding forest predictions to the residuals + std::function op; + if (add) + op = std::plus(); + else + op = std::minus(); + + // Perform the update (addition / subtraction) operation + StochTree::UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), requires_basis, op); } [[cpp11::register]] -void propagate_basis_update_active_forest_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer active_forest, +void propagate_basis_update_active_forest_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer active_forest, cpp11::external_pointer tracker) { - // Perform the update (addition / subtraction) operation - 
StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, active_forest.get()); + // Perform the update (addition / subtraction) operation + StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, active_forest.get()); } [[cpp11::register]] -void reset_active_forest_cpp(cpp11::external_pointer active_forest, - cpp11::external_pointer forest_samples, +void reset_active_forest_cpp(cpp11::external_pointer active_forest, + cpp11::external_pointer forest_samples, int forest_num) { - // Extract raw pointer to the forest held at index forest_num - StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); + // Extract raw pointer to the forest held at index forest_num + StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); - // Reset active forest using the forest held at index forest_num - active_forest->ReconstituteFromForest(*forest); + // Reset active forest using the forest held at index forest_num + active_forest->ReconstituteFromForest(*forest); } [[cpp11::register]] -void reset_forest_model_cpp(cpp11::external_pointer forest_tracker, - cpp11::external_pointer forest, - cpp11::external_pointer data, - cpp11::external_pointer residual, +void reset_forest_model_cpp(cpp11::external_pointer forest_tracker, + cpp11::external_pointer forest, + cpp11::external_pointer data, + cpp11::external_pointer residual, bool is_mean_model) { - // Reset forest tracker using the forest held at index forest_num - forest_tracker->ReconstituteFromForest(*forest, *data, *residual, is_mean_model); + // Reset forest tracker using the forest held at index forest_num + forest_tracker->ReconstituteFromForest(*forest, *data, *residual, is_mean_model); } [[cpp11::register]] void root_reset_active_forest_cpp(cpp11::external_pointer active_forest) { - // Reset active forest to root - active_forest->ResetRoot(); + // Reset active forest to root + active_forest->ResetRoot(); } diff --git a/src/kernel.cpp b/src/kernel.cpp index 38fdd35c..1d3309d6 100644 
--- a/src/kernel.cpp +++ b/src/kernel.cpp @@ -8,38 +8,37 @@ typedef Eigen::Map forest_container, int forest_num) { - return forest_container->GetEnsemble(forest_num)->GetMaxLeafIndex() - 1; + return forest_container->GetEnsemble(forest_num)->GetMaxLeafIndex() - 1; } [[cpp11::register]] cpp11::writable::integers_matrix<> compute_leaf_indices_cpp( - cpp11::external_pointer forest_container, - cpp11::doubles_matrix<> covariates, cpp11::integers forest_nums -) { - // Wrap an Eigen Map around the raw data of the covariate matrix - StochTree::data_size_t num_obs = covariates.nrow(); - int num_covariates = covariates.ncol(); - double* covariate_data_ptr = REAL(PROTECT(covariates)); - DoubleMatrixType covariates_eigen(covariate_data_ptr, num_obs, num_covariates); - - // Extract other output dimensions - int num_trees = forest_container->NumTrees(); - int num_samples = forest_nums.size(); - - // Declare outputs - cpp11::writable::integers_matrix<> output_matrix(num_obs*num_trees, num_samples); - - // Wrap Eigen Maps around kernel and kernel inverse matrices - int* output_data_ptr = INTEGER(PROTECT(output_matrix)); - IntMatrixType output_eigen(output_data_ptr, num_obs*num_trees, num_samples); - - // Compute leaf indices - std::vector forest_indices(forest_nums.begin(), forest_nums.end()); - forest_container->PredictLeafIndicesInplace(covariates_eigen, output_eigen, forest_indices, num_trees, num_obs); - - // Unprotect pointers to R data - UNPROTECT(2); - - // Return matrix - return output_matrix; + cpp11::external_pointer forest_container, + cpp11::doubles_matrix<> covariates, cpp11::integers forest_nums) { + // Wrap an Eigen Map around the raw data of the covariate matrix + StochTree::data_size_t num_obs = covariates.nrow(); + int num_covariates = covariates.ncol(); + double* covariate_data_ptr = REAL(PROTECT(covariates)); + DoubleMatrixType covariates_eigen(covariate_data_ptr, num_obs, num_covariates); + + // Extract other output dimensions + int num_trees = 
forest_container->NumTrees(); + int num_samples = forest_nums.size(); + + // Declare outputs + cpp11::writable::integers_matrix<> output_matrix(num_obs * num_trees, num_samples); + + // Wrap Eigen Maps around kernel and kernel inverse matrices + int* output_data_ptr = INTEGER(PROTECT(output_matrix)); + IntMatrixType output_eigen(output_data_ptr, num_obs * num_trees, num_samples); + + // Compute leaf indices + std::vector forest_indices(forest_nums.begin(), forest_nums.end()); + forest_container->PredictLeafIndicesInplace(covariates_eigen, output_eigen, forest_indices, num_trees, num_obs); + + // Unprotect pointers to R data + UNPROTECT(2); + + // Return matrix + return output_matrix; } diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index c456c29b..5538db9f 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -5,37 +5,31 @@ namespace StochTree { double GaussianConstantLeafModel::SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance) { - double left_log_ml = ( - -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) - ); + double left_log_ml = (-0.5 * std::log(1 + tau_ * (left_stat.sum_w / global_variance)) + ((tau_ * left_stat.sum_yw * left_stat.sum_yw) / (2.0 * global_variance * (tau_ * left_stat.sum_w + global_variance)))); - double right_log_ml = ( - -0.5*std::log(1 + tau_*(right_stat.sum_w/global_variance)) + ((tau_*right_stat.sum_yw*right_stat.sum_yw)/(2.0*global_variance*(tau_*right_stat.sum_w + global_variance))) - ); + double right_log_ml = (-0.5 * std::log(1 + tau_ * (right_stat.sum_w / global_variance)) + ((tau_ * right_stat.sum_yw * right_stat.sum_yw) / (2.0 * global_variance * (tau_ * right_stat.sum_w + global_variance)))); return left_log_ml + right_log_ml; } double GaussianConstantLeafModel::NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, 
double global_variance) { - double log_ml = ( - -0.5*std::log(1 + tau_*(suff_stat.sum_w/global_variance)) + ((tau_*suff_stat.sum_yw*suff_stat.sum_yw)/(2.0*global_variance*(tau_*suff_stat.sum_w + global_variance))) - ); + double log_ml = (-0.5 * std::log(1 + tau_ * (suff_stat.sum_w / global_variance)) + ((tau_ * suff_stat.sum_yw * suff_stat.sum_yw) / (2.0 * global_variance * (tau_ * suff_stat.sum_w + global_variance)))); return log_ml; } double GaussianConstantLeafModel::PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance) { - return (tau_*suff_stat.sum_yw) / (suff_stat.sum_w*tau_ + global_variance); + return (tau_ * suff_stat.sum_yw) / (suff_stat.sum_w * tau_ + global_variance); } double GaussianConstantLeafModel::PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance) { - return (tau_*global_variance) / (suff_stat.sum_w*tau_ + global_variance); + return (tau_ * global_variance) / (suff_stat.sum_w * tau_ + global_variance); } void GaussianConstantLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { // Vector of leaf indices for tree std::vector tree_leaves = tree->GetLeaves(); - + // Initialize sufficient statistics GaussianConstantSuffStat node_suff_stat = GaussianConstantSuffStat(); @@ -49,11 +43,11 @@ void GaussianConstantLeafModel::SampleLeafParameters(ForestDataset& dataset, For leaf_id = tree_leaves[i]; node_suff_stat.ResetSuffStat(); AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); - + // Compute posterior mean and variance node_mean = PosteriorParameterMean(node_suff_stat, global_variance); node_variance = PosteriorParameterVariance(node_suff_stat, global_variance); - + // Draw from N(mean, stddev^2) and set the leaf parameter with each draw node_mu = normal_sampler_.Sample(node_mean, node_variance, gen); tree->SetLeaf(leaf_id, 
node_mu); @@ -69,37 +63,31 @@ void GaussianConstantLeafModel::SetEnsembleRootPredictedValue(ForestDataset& dat } double GaussianUnivariateRegressionLeafModel::SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance) { - double left_log_ml = ( - -0.5*std::log(1 + tau_*(left_stat.sum_xxw/global_variance)) + ((tau_*left_stat.sum_yxw*left_stat.sum_yxw)/(2.0*global_variance*(tau_*left_stat.sum_xxw + global_variance))) - ); + double left_log_ml = (-0.5 * std::log(1 + tau_ * (left_stat.sum_xxw / global_variance)) + ((tau_ * left_stat.sum_yxw * left_stat.sum_yxw) / (2.0 * global_variance * (tau_ * left_stat.sum_xxw + global_variance)))); - double right_log_ml = ( - -0.5*std::log(1 + tau_*(right_stat.sum_xxw/global_variance)) + ((tau_*right_stat.sum_yxw*right_stat.sum_yxw)/(2.0*global_variance*(tau_*right_stat.sum_xxw + global_variance))) - ); + double right_log_ml = (-0.5 * std::log(1 + tau_ * (right_stat.sum_xxw / global_variance)) + ((tau_ * right_stat.sum_yxw * right_stat.sum_yxw) / (2.0 * global_variance * (tau_ * right_stat.sum_xxw + global_variance)))); return left_log_ml + right_log_ml; } double GaussianUnivariateRegressionLeafModel::NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance) { - double log_ml = ( - -0.5*std::log(1 + tau_*(suff_stat.sum_xxw/global_variance)) + ((tau_*suff_stat.sum_yxw*suff_stat.sum_yxw)/(2.0*global_variance*(tau_*suff_stat.sum_xxw + global_variance))) - ); + double log_ml = (-0.5 * std::log(1 + tau_ * (suff_stat.sum_xxw / global_variance)) + ((tau_ * suff_stat.sum_yxw * suff_stat.sum_yxw) / (2.0 * global_variance * (tau_ * suff_stat.sum_xxw + global_variance)))); return log_ml; } double GaussianUnivariateRegressionLeafModel::PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance) { - return (tau_*suff_stat.sum_yxw) / (suff_stat.sum_xxw*tau_ + global_variance); 
+ return (tau_ * suff_stat.sum_yxw) / (suff_stat.sum_xxw * tau_ + global_variance); } double GaussianUnivariateRegressionLeafModel::PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance) { - return (tau_*global_variance) / (suff_stat.sum_xxw*tau_ + global_variance); + return (tau_ * global_variance) / (suff_stat.sum_xxw * tau_ + global_variance); } void GaussianUnivariateRegressionLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { // Vector of leaf indices for tree std::vector tree_leaves = tree->GetLeaves(); - + // Initialize sufficient statistics GaussianUnivariateRegressionSuffStat node_suff_stat = GaussianUnivariateRegressionSuffStat(); @@ -113,11 +101,11 @@ void GaussianUnivariateRegressionLeafModel::SampleLeafParameters(ForestDataset& leaf_id = tree_leaves[i]; node_suff_stat.ResetSuffStat(); AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); - + // Compute posterior mean and variance node_mean = PosteriorParameterMean(node_suff_stat, global_variance); node_variance = PosteriorParameterVariance(node_suff_stat, global_variance); - + // Draw from N(mean, stddev^2) and set the leaf parameter with each draw node_mu = normal_sampler_.Sample(node_mean, node_variance, gen); tree->SetLeaf(leaf_id, node_mu); @@ -134,38 +122,32 @@ void GaussianUnivariateRegressionLeafModel::SetEnsembleRootPredictedValue(Forest double GaussianMultivariateRegressionLeafModel::SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance) { Eigen::MatrixXd I_p = Eigen::MatrixXd::Identity(left_stat.p, left_stat.p); - double left_log_ml = ( - -0.5*std::log((I_p + (Sigma_0_ * left_stat.XtWX)/global_variance).determinant()) + 0.5*((left_stat.ytWX/global_variance) * (Sigma_0_.inverse() + 
(left_stat.XtWX/global_variance)).inverse() * (left_stat.ytWX/global_variance).transpose()).value() - ); + double left_log_ml = (-0.5 * std::log((I_p + (Sigma_0_ * left_stat.XtWX) / global_variance).determinant()) + 0.5 * ((left_stat.ytWX / global_variance) * (Sigma_0_.inverse() + (left_stat.XtWX / global_variance)).inverse() * (left_stat.ytWX / global_variance).transpose()).value()); - double right_log_ml = ( - -0.5*std::log((I_p + (Sigma_0_ * right_stat.XtWX)/global_variance).determinant()) + 0.5*((right_stat.ytWX/global_variance) * (Sigma_0_.inverse() + (right_stat.XtWX/global_variance)).inverse() * (right_stat.ytWX/global_variance).transpose()).value() - ); + double right_log_ml = (-0.5 * std::log((I_p + (Sigma_0_ * right_stat.XtWX) / global_variance).determinant()) + 0.5 * ((right_stat.ytWX / global_variance) * (Sigma_0_.inverse() + (right_stat.XtWX / global_variance)).inverse() * (right_stat.ytWX / global_variance).transpose()).value()); return left_log_ml + right_log_ml; } double GaussianMultivariateRegressionLeafModel::NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance) { Eigen::MatrixXd I_p = Eigen::MatrixXd::Identity(suff_stat.p, suff_stat.p); - double log_ml = ( - -0.5*std::log((I_p + (Sigma_0_ * suff_stat.XtWX)/global_variance).determinant()) + 0.5*((suff_stat.ytWX/global_variance) * (Sigma_0_.inverse() + (suff_stat.XtWX/global_variance)).inverse() * (suff_stat.ytWX/global_variance).transpose()).value() - ); + double log_ml = (-0.5 * std::log((I_p + (Sigma_0_ * suff_stat.XtWX) / global_variance).determinant()) + 0.5 * ((suff_stat.ytWX / global_variance) * (Sigma_0_.inverse() + (suff_stat.XtWX / global_variance)).inverse() * (suff_stat.ytWX / global_variance).transpose()).value()); return log_ml; } Eigen::VectorXd GaussianMultivariateRegressionLeafModel::PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance) { - return (Sigma_0_.inverse() + 
(suff_stat.XtWX/global_variance)).inverse() * (suff_stat.ytWX/global_variance).transpose(); + return (Sigma_0_.inverse() + (suff_stat.XtWX / global_variance)).inverse() * (suff_stat.ytWX / global_variance).transpose(); } Eigen::MatrixXd GaussianMultivariateRegressionLeafModel::PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance) { - return (Sigma_0_.inverse() + (suff_stat.XtWX/global_variance)).inverse(); + return (Sigma_0_.inverse() + (suff_stat.XtWX / global_variance)).inverse(); } void GaussianMultivariateRegressionLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { // Vector of leaf indices for tree std::vector tree_leaves = tree->GetLeaves(); - + // Initialize sufficient statistics int num_basis = dataset.GetBasis().cols(); GaussianMultivariateRegressionSuffStat node_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); @@ -180,11 +162,11 @@ void GaussianMultivariateRegressionLeafModel::SampleLeafParameters(ForestDataset leaf_id = tree_leaves[i]; node_suff_stat.ResetSuffStat(); AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); - + // Compute posterior mean and variance node_mean = PosteriorParameterMean(node_suff_stat, global_variance); node_variance = PosteriorParameterVariance(node_suff_stat, global_variance); - + // Draw from N(mean, stddev^2) and set the leaf parameter with each draw node_mu = multivariate_normal_sampler_.Sample(node_mean, node_variance, gen); tree->SetLeafVector(leaf_id, node_mu); @@ -194,7 +176,7 @@ void GaussianMultivariateRegressionLeafModel::SampleLeafParameters(ForestDataset void GaussianMultivariateRegressionLeafModel::SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value) { int num_trees = ensemble->NumTrees(); int num_basis = dataset.GetBasis().cols(); - + // Check 
that root predicted value is close to 0 // TODO: formalize and document this if ((root_pred_value < -0.1) || root_pred_value > 0.1) { @@ -240,7 +222,7 @@ double LogLinearVarianceLeafModel::PosteriorParameterScale(LogLinearVarianceSuff void LogLinearVarianceLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { // Vector of leaf indices for tree std::vector tree_leaves = tree->GetLeaves(); - + // Initialize sufficient statistics LogLinearVarianceSuffStat node_suff_stat = LogLinearVarianceSuffStat(); @@ -254,11 +236,11 @@ void LogLinearVarianceLeafModel::SampleLeafParameters(ForestDataset& dataset, Fo leaf_id = tree_leaves[i]; node_suff_stat.ResetSuffStat(); AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); - + // Compute posterior mean and variance node_shape = PosteriorParameterShape(node_suff_stat, global_variance); node_rate = PosteriorParameterScale(node_suff_stat, global_variance); - + // Draw from IG(shape, scale) and set the leaf parameter with each draw node_mu = -std::log(sample_gamma(gen, node_shape, 1.) 
/ node_rate); // node_mu = std::log(gamma_sampler_.Sample(node_shape, node_rate, gen, true)); @@ -306,7 +288,7 @@ double CloglogOrdinalLeafModel::PosteriorParameterRate(CloglogOrdinalSuffStat& s void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { // Vector of leaf indices for tree std::vector tree_leaves = tree->GetLeaves(); - + // Initialize sufficient statistics CloglogOrdinalSuffStat node_suff_stat = CloglogOrdinalSuffStat(); @@ -320,7 +302,7 @@ void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, Fores leaf_id = tree_leaves[i]; node_suff_stat.ResetSuffStat(); AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); - + // Compute posterior shape and rate node_shape = PosteriorParameterShape(node_suff_stat, global_variance); node_rate = PosteriorParameterRate(node_suff_stat, global_variance); @@ -334,4 +316,4 @@ void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, Fores } } -} // namespace StochTree +} // namespace StochTree diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 81841bde..dcb62ab6 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -17,9 +17,9 @@ double OrdinalSampler::SampleTruncatedExponential(std::mt19937& gen, double rate void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen) { // Get auxiliary data vectors - const std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // gamma cutpoints + const std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // gamma cutpoints const std::vector& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) - std::vector& Z = dataset.GetAuxiliaryDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) + std::vector& Z = 
dataset.GetAuxiliaryDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); @@ -41,12 +41,12 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector } } -void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, - double alpha_gamma, double beta_gamma, +void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, + double alpha_gamma, double beta_gamma, double gamma_0, std::mt19937& gen) { // Get auxiliary data vectors - std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's - const std::vector& Z = dataset.GetAuxiliaryDataVector(0); // latent variables z_i's + std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's + const std::vector& Z = dataset.GetAuxiliaryDataVector(0); // latent variables z_i's const std::vector& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) int K = gamma.size() + 1; // Number of ordinal categories @@ -78,23 +78,23 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability // if (K > 2) { - gamma[0] = gamma_0; + gamma[0] = gamma_0; // } } void OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) { // Get auxiliary data vectors const std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's - std::vector& seg = dataset.GetAuxiliaryDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) + std::vector& seg = dataset.GetAuxiliaryDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) // Update seg (sum of exponentials of gamma cutpoints) for (int j = 0; j < static_cast(seg.size()); j++) { if (j == 0) { - seg[j] = 0.0; // checked and it is correct + seg[j] = 0.0; // checked and it is correct } else { seg[j] = 
seg[j - 1] + std::exp(gamma[j - 1]); // checked and it is correct } } } -} // namespace StochTree +} // namespace StochTree diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 73b37fe8..65f339e6 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -36,7 +36,6 @@ void ForestTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& // Reconstitute each of the remaining data structures in the tracker based on splits in the ensemble // UnsortedNodeSampleTracker unsorted_node_sample_tracker_->ReconstituteFromForest(forest, dataset); - } void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vector& feature_types, int32_t tree_num) { @@ -45,19 +44,19 @@ void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vectorGetNodeId(observation_num, tree_num);} +data_size_t ForestTracker::GetNodeId(int observation_num, int tree_num) { return sample_node_mapper_->GetNodeId(observation_num, tree_num); } -data_size_t ForestTracker::UnsortedNodeBegin(int tree_id, int node_id) {return unsorted_node_sample_tracker_->NodeBegin(tree_id, node_id);} +data_size_t ForestTracker::UnsortedNodeBegin(int tree_id, int node_id) { return unsorted_node_sample_tracker_->NodeBegin(tree_id, node_id); } -data_size_t ForestTracker::UnsortedNodeEnd(int tree_id, int node_id) {return unsorted_node_sample_tracker_->NodeEnd(tree_id, node_id);} +data_size_t ForestTracker::UnsortedNodeEnd(int tree_id, int node_id) { return unsorted_node_sample_tracker_->NodeEnd(tree_id, node_id); } -data_size_t ForestTracker::UnsortedNodeSize(int tree_id, int node_id) {return unsorted_node_sample_tracker_->NodeSize(tree_id, node_id);} +data_size_t ForestTracker::UnsortedNodeSize(int tree_id, int node_id) { return unsorted_node_sample_tracker_->NodeSize(tree_id, node_id); } -data_size_t ForestTracker::SortedNodeBegin(int node_id, int feature_id) {return sorted_node_sample_tracker_->NodeBegin(node_id, feature_id);} +data_size_t 
ForestTracker::SortedNodeBegin(int node_id, int feature_id) { return sorted_node_sample_tracker_->NodeBegin(node_id, feature_id); } -data_size_t ForestTracker::SortedNodeEnd(int node_id, int feature_id) {return sorted_node_sample_tracker_->NodeEnd(node_id, feature_id);} +data_size_t ForestTracker::SortedNodeEnd(int node_id, int feature_id) { return sorted_node_sample_tracker_->NodeEnd(node_id, feature_id); } -data_size_t ForestTracker::SortedNodeSize(int node_id, int feature_id) {return sorted_node_sample_tracker_->NodeSize(node_id, feature_id);} +data_size_t ForestTracker::SortedNodeSize(int node_id, int feature_id) { return sorted_node_sample_tracker_->NodeSize(node_id, feature_id); } std::vector::iterator ForestTracker::UnsortedNodeBeginIterator(int tree_id, int node_id) { return unsorted_node_sample_tracker_->NodeBeginIterator(tree_id, node_id); @@ -84,7 +83,7 @@ void ForestTracker::AssignAllSamplesToRoot(int32_t tree_num) { void ForestTracker::AssignAllSamplesToConstantPrediction(double value) { for (data_size_t i = 0; i < num_observations_; i++) { - sum_predictions_[i] = value*num_trees_; + sum_predictions_[i] = value * num_trees_; } for (int i = 0; i < num_trees_; i++) { sample_pred_mapper_->AssignAllSamplesToConstantPrediction(i, value); @@ -565,8 +564,7 @@ bool FeatureUnsortedPartition::IsValidNode(int node_id) { if (node_id >= num_nodes_ || node_id < 0) { return false; } - return !(std::find(deleted_nodes_.begin(), deleted_nodes_.end(), node_id) - != deleted_nodes_.end()); + return !(std::find(deleted_nodes_.begin(), deleted_nodes_.end(), node_id) != deleted_nodes_.end()); } bool FeatureUnsortedPartition::LeftNodeIsLeaf(int node_id) { diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 5b7ff265..020bfbaf 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -34,7 +34,7 @@ class ForestDatasetCpp { void AddCovariates(py::array_t covariate_matrix, data_size_t num_row, int num_col, bool row_major) { // Extract pointer to contiguous block 
of memory double* data_ptr = static_cast(covariate_matrix.mutable_data()); - + // Load covariates dataset_->AddCovariates(data_ptr, num_row, num_col, row_major); } @@ -42,7 +42,7 @@ class ForestDatasetCpp { void AddBasis(py::array_t basis_matrix, data_size_t num_row, int num_col, bool row_major) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(basis_matrix.mutable_data()); - + // Load covariates dataset_->AddBasis(data_ptr, num_row, num_col, row_major); } @@ -50,7 +50,7 @@ class ForestDatasetCpp { void UpdateBasis(py::array_t basis_matrix, data_size_t num_row, int num_col, bool row_major) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(basis_matrix.mutable_data()); - + // Load covariates dataset_->UpdateBasis(data_ptr, num_row, num_col, row_major); } @@ -58,7 +58,7 @@ class ForestDatasetCpp { void AddVarianceWeights(py::array_t weight_vector, data_size_t num_row) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(weight_vector.mutable_data()); - + // Load covariates dataset_->AddVarianceWeights(data_ptr, num_row); } @@ -66,7 +66,7 @@ class ForestDatasetCpp { void UpdateVarianceWeights(py::array_t weight_vector, data_size_t num_row, bool exponentiate) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(weight_vector.mutable_data()); - + // Load covariates dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate); } @@ -79,7 +79,7 @@ class ForestDatasetCpp { auto accessor = result.mutable_unchecked<2>(); for (size_t i = 0; i < n; i++) { for (int j = 0; j < num_covariates; j++) { - accessor(i,j) = dataset_->CovariateValue(i,j); + accessor(i, j) = dataset_->CovariateValue(i, j); } } @@ -94,7 +94,7 @@ class ForestDatasetCpp { auto accessor = result.mutable_unchecked<2>(); for (size_t i = 0; i < n; i++) { for (int j = 0; j < num_basis; j++) { - accessor(i,j) = dataset_->BasisValue(i,j); + accessor(i, j) = dataset_->BasisValue(i, j); } } @@ 
-169,7 +169,7 @@ class ResidualCpp { ResidualCpp(py::array_t residual_array, data_size_t num_row) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(residual_array.mutable_data()); - + // Initialize pointer to C++ ColumnVector class residual_ = std::make_unique(data_ptr, num_row); } @@ -182,13 +182,13 @@ class ResidualCpp { py::array_t GetResidualArray() { // Obtain a reference to the underlying Eigen::VectorXd Eigen::VectorXd& resid_vector = residual_->GetData(); - + // Initialize n x 1 numpy array to store the residual data_size_t n = residual_->NumRows(); auto result = py::array_t(py::detail::any_container({n, 1})); auto accessor = result.mutable_unchecked<2>(); for (size_t i = 0; i < n; i++) { - accessor(i,0) = resid_vector(i); + accessor(i, 0) = resid_vector(i); } return result; @@ -311,7 +311,7 @@ class ForestContainerCpp { for (size_t i = 0; i < n; i++) { for (int j = 0; j < num_samples; j++) { // NOTE: converting from "column-major" to "row-major" here - accessor(i,j) = output_raw[j*n + i]; + accessor(i, j) = output_raw[j * n + i]; // ptr[i*num_samples + j] = output_raw[j*n + i]; } } @@ -335,7 +335,7 @@ class ForestContainerCpp { for (size_t i = 0; i < n; i++) { for (int j = 0; j < output_dim; j++) { for (int k = 0; k < num_samples; k++) { - accessor(i,k,j) = output_raw[k*(output_dim*n) + i*output_dim + j]; + accessor(i, k, j) = output_raw[k * (output_dim * n) + i * output_dim + j]; // ptr[i*(output_dim*num_samples) + j*output_dim + k] = output_raw[k*(output_dim*n) + i*output_dim + j]; } } @@ -359,7 +359,7 @@ class ForestContainerCpp { // double *ptr = static_cast(buf.ptr); for (size_t i = 0; i < n; i++) { for (int j = 0; j < output_dim; j++) { - accessor(i,j) = output_raw[i*output_dim + j]; + accessor(i, j) = output_raw[i * output_dim + j]; // ptr[i*output_dim + j] = output_raw[i*output_dim + j]; } } @@ -382,7 +382,7 @@ class ForestContainerCpp { // double *ptr = static_cast(buf.ptr); for (size_t i = 0; i < n; i++) { for (int 
j = 0; j < output_dim; j++) { - accessor(i,j) = output_raw[i*output_dim + j]; + accessor(i, j) = output_raw[i * output_dim + j]; // ptr[i*output_dim + j] = output_raw[i*output_dim + j]; } } @@ -397,7 +397,7 @@ class ForestContainerCpp { void SetRootVector(int forest_num, py::array_t& leaf_vector, int leaf_size) { std::vector leaf_vector_converted(leaf_size); for (int i = 0; i < leaf_size; i++) { - leaf_vector_converted[i] = leaf_vector.at(i); + leaf_vector_converted[i] = leaf_vector.at(i); } forest_samples_->InitializeRoot(leaf_vector_converted); } @@ -453,8 +453,8 @@ class ForestContainerCpp { StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(num_samples); int num_trees = ensemble->NumTrees(); for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = ensemble->GetTree(i); - tree->SetLeaf(0, leaf_value); + StochTree::Tree* tree = ensemble->GetTree(i); + tree->SetLeaf(0, leaf_value); } } @@ -469,13 +469,13 @@ class ForestContainerCpp { StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(num_samples); int num_trees = ensemble->NumTrees(); for (int i = 0; i < num_trees; i++) { - StochTree::Tree* tree = ensemble->GetTree(i); - tree->SetLeafVector(0, leaf_vector_cast); + StochTree::Tree* tree = ensemble->GetTree(i); + tree->SetLeafVector(0, leaf_vector_cast); } } - void AddNumericSplitVector(int forest_num, int tree_num, int leaf_num, int feature_num, - double split_threshold, py::array_t left_leaf_vector, + void AddNumericSplitVector(int forest_num, int tree_num, int leaf_num, int feature_num, + double split_threshold, py::array_t left_leaf_vector, py::array_t right_leaf_vector) { if (forest_samples_->OutputDimension() != left_leaf_vector.size()) { StochTree::Log::Fatal("left_leaf_vector must match forest leaf dimension"); @@ -495,8 +495,8 @@ class ForestContainerCpp { tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast); } - void AddNumericSplitValue(int forest_num, int tree_num, int 
leaf_num, int feature_num, - double split_threshold, double left_leaf_value, double right_leaf_value) { + void AddNumericSplitValue(int forest_num, int tree_num, int leaf_num, int feature_num, + double split_threshold, double left_leaf_value, double right_leaf_value) { if (forest_samples_->OutputDimension() != 1) { StochTree::Log::Fatal("left_leaf_value must match forest leaf dimension"); } @@ -534,9 +534,9 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(tree_num); std::vector split_nodes = tree->GetInternalNodes(); for (int i = 0; i < split_nodes.size(); i++) { - auto node_id = split_nodes.at(i); - auto split_feature = tree->SplitIndex(node_id); - accessor(split_feature)++; + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); + accessor(split_feature)++; } return result; } @@ -587,12 +587,12 @@ class ForestContainerCpp { py::array_t GetGranularSplitCounts(int num_features) { int num_samples = forest_samples_->NumSamples(); int num_trees = forest_samples_->NumTrees(); - auto result = py::array_t(py::detail::any_container({num_samples,num_trees,num_features})); + auto result = py::array_t(py::detail::any_container({num_samples, num_trees, num_features})); auto accessor = result.mutable_unchecked<3>(); for (int i = 0; i < num_samples; i++) { for (int j = 0; j < num_trees; j++) { for (int k = 0; k < num_features; k++) { - accessor(i,j,k) = 0; + accessor(i, j, k) = 0; } } } @@ -604,7 +604,7 @@ class ForestContainerCpp { for (int k = 0; k < split_nodes.size(); k++) { auto node_id = split_nodes.at(k); auto split_feature = tree->SplitIndex(node_id); - accessor(i,j,split_feature)++; + accessor(i, j, split_feature)++; } } } @@ -760,7 +760,7 @@ class ForestCpp { } ~ForestCpp() {} - StochTree::TreeEnsemble* GetForestPtr() {return forest_.get();} + StochTree::TreeEnsemble* GetForestPtr() { return forest_.get(); } void MergeForest(ForestCpp& outbound_forest) { forest_->MergeForest(*outbound_forest.GetForestPtr()); @@ 
-813,7 +813,7 @@ class ForestCpp { auto result = py::array_t(py::detail::any_container({n})); auto accessor = result.mutable_unchecked<1>(); for (size_t i = 0; i < n; i++) { - accessor(i) = output_raw[i]; + accessor(i) = output_raw[i]; } return result; @@ -831,7 +831,7 @@ class ForestCpp { auto accessor = result.mutable_unchecked<2>(); for (size_t i = 0; i < n; i++) { for (int j = 0; j < output_dim; j++) { - accessor(i,j) = output_raw[i*output_dim + j]; + accessor(i, j) = output_raw[i * output_dim + j]; } } @@ -845,7 +845,7 @@ class ForestCpp { void SetRootVector(py::array_t& leaf_vector, int leaf_size) { std::vector leaf_vector_converted(leaf_size); for (int i = 0; i < leaf_size; i++) { - leaf_vector_converted[i] = leaf_vector.at(i); + leaf_vector_converted[i] = leaf_vector.at(i); } forest_->SetLeafVector(leaf_vector_converted); } @@ -856,7 +856,7 @@ class ForestCpp { return forest_.get(); } - void AddNumericSplitValue(int tree_num, int leaf_num, int feature_num, double split_threshold, + void AddNumericSplitValue(int tree_num, int leaf_num, int feature_num, double split_threshold, double left_leaf_value, double right_leaf_value) { if (forest_->OutputDimension() != 1) { StochTree::Log::Fatal("left_leaf_value must match forest leaf dimension"); @@ -872,7 +872,7 @@ class ForestCpp { tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value); } - void AddNumericSplitVector(int tree_num, int leaf_num, int feature_num, double split_threshold, + void AddNumericSplitVector(int tree_num, int leaf_num, int feature_num, double split_threshold, py::array_t left_leaf_vector, py::array_t right_leaf_vector) { if (forest_->OutputDimension() != left_leaf_vector.size()) { StochTree::Log::Fatal("left_leaf_vector must match forest leaf dimension"); @@ -913,9 +913,9 @@ class ForestCpp { StochTree::Tree* tree = forest_->GetTree(tree_num); std::vector split_nodes = tree->GetInternalNodes(); for (int i = 0; i < split_nodes.size(); i++) { - auto node_id = 
split_nodes.at(i); - auto split_feature = tree->SplitIndex(node_id); - accessor(split_feature)++; + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); + accessor(split_feature)++; } return result; } @@ -941,11 +941,11 @@ class ForestCpp { py::array_t GetGranularSplitCounts(int num_features) { int num_trees = forest_->NumTrees(); - auto result = py::array_t(py::detail::any_container({num_trees,num_features})); + auto result = py::array_t(py::detail::any_container({num_trees, num_features})); auto accessor = result.mutable_unchecked<2>(); for (int i = 0; i < num_trees; i++) { for (int j = 0; j < num_features; j++) { - accessor(i,j) = 0; + accessor(i, j) = 0; } } for (int i = 0; i < num_trees; i++) { @@ -954,7 +954,7 @@ class ForestCpp { for (int j = 0; j < split_nodes.size(); j++) { auto node_id = split_nodes.at(i); auto split_feature = tree->SplitIndex(node_id); - accessor(i,split_feature)++; + accessor(i, split_feature)++; } } return result; @@ -1086,9 +1086,9 @@ class ForestSamplerCpp { // Convert vector of integers to std::vector of enum FeatureType std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { - feature_types_[i] = static_cast(feature_types.at(i)); + feature_types_[i] = static_cast(feature_types.at(i)); } - + // Initialize pointer to C++ ForestTracker and TreePrior classes StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); tracker_ = std::make_unique(dataset_ptr->GetCovariates(), feature_types_, num_trees, num_obs); @@ -1096,75 +1096,80 @@ class ForestSamplerCpp { } ~ForestSamplerCpp() {} - StochTree::ForestTracker* GetTracker() {return tracker_.get();} + StochTree::ForestTracker* GetTracker() { return tracker_.get(); } void ReconstituteTrackerFromForest(ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, bool is_mean_model) { // Extract raw pointer to the forest and dataset StochTree::TreeEnsemble* forest_ptr = forest.GetEnsemble(); 
StochTree::ForestDataset* data_ptr = dataset.GetDataset(); StochTree::ColumnVector* residual_ptr = residual.GetData(); - + // Reset forest tracker using the forest held at index forest_num tracker_->ReconstituteFromForest(*forest_ptr, *data_ptr, *residual_ptr, is_mean_model); } - void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, - py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, - py::array_t variable_weights, double a_forest, double b_forest, double global_variance, + void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, + py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, + py::array_t variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true, int num_threads = -1) { // Refactoring completely out of the Python interface. // Intention to refactor out of the C++ interface in the future. 
bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types.at(i)); } - + // Unpack sweep indices std::vector sweep_update_indices_; if (sweep_update_indices.size() > 0) { - sweep_update_indices_.resize(sweep_update_indices.size()); - for (int i = 0; i < sweep_update_indices.size(); i++) { - sweep_update_indices_[i] = sweep_update_indices.at(i); - } + sweep_update_indices_.resize(sweep_update_indices.size()); + for (int i = 0; i < sweep_update_indices.size(); i++) { + sweep_update_indices_[i] = sweep_update_indices.at(i); + } } // Convert leaf model type to enum StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; - else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; - else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) + model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) + model_type = StochTree::ModelType::kCloglogOrdinal; // Unpack leaf model parameters double leaf_scale; Eigen::MatrixXd leaf_scale_matrix; if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { - leaf_scale = leaf_model_scale_input.at(0,0); + leaf_scale = leaf_model_scale_input.at(0, 
0); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int num_row = leaf_model_scale_input.shape(0); - int num_col = leaf_model_scale_input.shape(1); - leaf_scale_matrix.resize(num_row, num_col); - for (int i = 0; i < num_row; i++) { - for (int j = 0; j < num_col; j++) { - leaf_scale_matrix(i,j) = leaf_model_scale_input.at(i,j); - } + int num_row = leaf_model_scale_input.shape(0); + int num_col = leaf_model_scale_input.shape(1); + leaf_scale_matrix.resize(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + leaf_scale_matrix(i, j) = leaf_model_scale_input.at(i, j); } + } } // Convert variable weights to std::vector std::vector var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { - var_weights_vector[i] = variable_weights.at(i); + var_weights_vector[i] = variable_weights.at(i); } // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); - + // Run one iteration of the sampler StochTree::ForestContainer* forest_sample_ptr = forest_samples.GetContainer(); StochTree::TreeEnsemble* active_forest_ptr = forest.GetEnsemble(); @@ -1199,16 +1204,22 @@ class ForestSamplerCpp { } } - void InitializeForestModel(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest, + void InitializeForestModel(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest, int leaf_model_int, py::array_t initial_values) { // Convert leaf model type to enum StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; - else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = 
StochTree::ModelType::kLogLinearVariance; - else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; - else StochTree::Log::Fatal("Invalid model type"); + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) + model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) + model_type = StochTree::ModelType::kCloglogOrdinal; + else + StochTree::Log::Fatal("Invalid model type"); // Unpack initial value StochTree::TreeEnsemble* forest_ptr = forest.GetEnsemble(); @@ -1221,43 +1232,43 @@ class ForestSamplerCpp { (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || (model_type == StochTree::ModelType::kLogLinearVariance) || (model_type == StochTree::ModelType::kCloglogOrdinal)) { - init_val = initial_values.at(0); + init_val = initial_values.at(0); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int leaf_dim = initial_values.size(); - init_value_vector.resize(leaf_dim); - for (int i = 0; i < leaf_dim; i++) { - init_value_vector[i] = initial_values.at(i) / static_cast(num_trees); - } + int leaf_dim = initial_values.size(); + init_value_vector.resize(leaf_dim); + for (int i = 0; i < leaf_dim; i++) { + init_value_vector[i] = initial_values.at(i) / static_cast(num_trees); + } } - + // Initialize the models accordingly double leaf_init_val; if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - leaf_init_val = init_val / static_cast(num_trees); - forest_ptr->SetLeafValue(leaf_init_val); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, false, std::minus()); - tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); + leaf_init_val = 
init_val / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, false, std::minus()); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - leaf_init_val = init_val / static_cast(num_trees); - forest_ptr->SetLeafValue(leaf_init_val); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus()); - tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); + leaf_init_val = init_val / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus()); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - forest_ptr->SetLeafVector(init_value_vector); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus()); - tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); + forest_ptr->SetLeafVector(init_value_vector); + StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus()); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - leaf_init_val = std::log(init_val) / static_cast(num_trees); - forest_ptr->SetLeafValue(leaf_init_val); - tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); - int n = forest_data_ptr->NumObservations(); - std::vector initial_preds(n, init_val); - forest_data_ptr->AddVarianceWeights(initial_preds.data(), n); + leaf_init_val = std::log(init_val) / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + tracker_->UpdatePredictions(forest_ptr, 
*forest_data_ptr); + int n = forest_data_ptr->NumObservations(); + std::vector initial_preds(n, init_val); + forest_data_ptr->AddVarianceWeights(initial_preds.data(), n); } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { - leaf_init_val = init_val / static_cast(num_trees); - forest_ptr->SetLeafValue(leaf_init_val); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, false, std::minus()); - tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); + leaf_init_val = init_val / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, false, std::minus()); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } } @@ -1331,7 +1342,7 @@ class GlobalVarianceModelCpp { StochTree::ColumnVector* residual_ptr = residual.GetData(); std::mt19937* rng_ptr = rng.GetRng(); return var_model_.SampleVarianceParameter(residual_ptr->GetData(), a, b, *rng_ptr); - } + } private: StochTree::GlobalHomoskedasticVarianceModel var_model_; @@ -1391,7 +1402,7 @@ class OrdinalSamplerCpp { class RandomEffectsDatasetCpp { public: - RandomEffectsDatasetCpp() { + RandomEffectsDatasetCpp() { rfx_dataset_ = std::make_unique(); } ~RandomEffectsDatasetCpp() {} @@ -1443,7 +1454,7 @@ class RandomEffectsDatasetCpp { auto accessor = result.mutable_unchecked<2>(); for (py::ssize_t i = 0; i < num_row; i++) { for (int j = 0; j < num_col; j++) { - accessor(i,j) = rfx_dataset_->BasisValue(i,j); + accessor(i, j) = rfx_dataset_->BasisValue(i, j); } } return result; @@ -1466,9 +1477,9 @@ class RandomEffectsDatasetCpp { } return result; } - bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();} - bool HasBasis() {return rfx_dataset_->HasBasis();} - bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();} + bool HasGroupLabels() { return rfx_dataset_->HasGroupLabels(); } + bool HasBasis() { return 
rfx_dataset_->HasBasis(); } + bool HasVarianceWeights() { return rfx_dataset_->HasVarWeights(); } private: std::unique_ptr rfx_dataset_; @@ -1508,7 +1519,7 @@ class RandomEffectsContainerCpp { for (int i = 0; i < num_components; i++) { for (int j = 0; j < num_groups; j++) { for (int k = 0; k < num_samples; k++) { - accessor(i,j,k) = beta_raw[k*num_groups*num_components + j*num_components + i]; + accessor(i, j, k) = beta_raw[k * num_groups * num_components + j * num_components + i]; } } } @@ -1524,7 +1535,7 @@ class RandomEffectsContainerCpp { for (int i = 0; i < num_components; i++) { for (int j = 0; j < num_groups; j++) { for (int k = 0; k < num_samples; k++) { - accessor(i,j,k) = xi_raw[k*num_groups*num_components + j*num_components + i]; + accessor(i, j, k) = xi_raw[k * num_groups * num_components + j * num_components + i]; } } } @@ -1538,7 +1549,7 @@ class RandomEffectsContainerCpp { auto accessor = result.mutable_unchecked<2>(); for (int i = 0; i < num_components; i++) { for (int j = 0; j < num_samples; j++) { - accessor(i,j) = alpha_raw[j*num_components + i]; + accessor(i, j) = alpha_raw[j * num_components + i]; } } return result; @@ -1551,7 +1562,7 @@ class RandomEffectsContainerCpp { auto accessor = result.mutable_unchecked<2>(); for (int i = 0; i < num_components; i++) { for (int j = 0; j < num_samples; j++) { - accessor(i,j) = sigma_raw[j*num_components + i]; + accessor(i, j) = sigma_raw[j * num_components + i]; } } return result; @@ -1577,7 +1588,7 @@ class RandomEffectsContainerCpp { StochTree::RandomEffectsContainer* GetRandomEffectsContainer() { return rfx_container_.get(); } - + private: std::unique_ptr rfx_container_; }; @@ -1665,8 +1676,8 @@ class RandomEffectsModelCpp { StochTree::MultivariateRegressionRandomEffectsModel* GetModel() { return rfx_model_.get(); } - void SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual, - RandomEffectsTrackerCpp& rfx_tracker, RandomEffectsContainerCpp& rfx_container, + void 
SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual, + RandomEffectsTrackerCpp& rfx_tracker, RandomEffectsContainerCpp& rfx_container, bool keep_sample, double global_variance, RngCpp& rng); py::array_t Predict(RandomEffectsDatasetCpp& rfx_dataset, RandomEffectsTrackerCpp& rfx_tracker) { std::vector output = rfx_model_->Predict(*rfx_dataset.GetDataset(), *rfx_tracker.GetTracker()); @@ -1691,7 +1702,7 @@ class RandomEffectsModelCpp { Eigen::MatrixXd group_params_eigen(nrow, ncol); for (py::ssize_t i = 0; i < nrow; i++) { for (int j = 0; j < ncol; j++) { - group_params_eigen(i,j) = group_params.at(i,j); + group_params_eigen(i, j) = group_params.at(i, j); } } rfx_model_->SetGroupParameters(group_params_eigen); @@ -1702,7 +1713,7 @@ class RandomEffectsModelCpp { Eigen::MatrixXd working_param_cov_eigen(nrow, ncol); for (int i = 0; i < nrow; i++) { for (int j = 0; j < ncol; j++) { - working_param_cov_eigen(i,j) = working_param_cov.at(i,j); + working_param_cov_eigen(i, j) = working_param_cov.at(i, j); } } rfx_model_->SetWorkingParameterCovariance(working_param_cov_eigen); @@ -1713,7 +1724,7 @@ class RandomEffectsModelCpp { Eigen::MatrixXd group_param_cov_eigen(nrow, ncol); for (int i = 0; i < nrow; i++) { for (int j = 0; j < ncol; j++) { - group_param_cov_eigen(i,j) = group_param_cov.at(i,j); + group_param_cov_eigen(i, j) = group_param_cov.at(i, j); } } rfx_model_->SetGroupParameterCovariance(group_param_cov_eigen); @@ -1795,7 +1806,7 @@ class JsonCpp { nlohmann::json groupids_json = nlohmann::json::array(); for (int i = 0; i < rfx_group_ids.size(); i++) { groupids_json.emplace_back(rfx_group_ids.at(i)); - } + } json_->at("random_effects").emplace(rfx_label, groupids_json); return rfx_label; } @@ -2153,9 +2164,9 @@ py::array_t cppComputeForestContainerLeafIndices(ForestContainerCpp& forest } // Compute leaf indices - auto result = py::array_t(py::detail::any_container({num_obs*num_trees, num_samples})); + auto result = 
py::array_t(py::detail::any_container({num_obs * num_trees, num_samples})); int* output_data_ptr = static_cast(result.mutable_data()); - Eigen::Map> output_eigen(output_data_ptr, num_obs*num_trees, num_samples); + Eigen::Map> output_eigen(output_data_ptr, num_obs * num_trees, num_samples); forest_container.GetContainer()->PredictLeafIndicesInplace(covariates_eigen, output_eigen, forest_indices, num_trees, num_obs); // Return matrix @@ -2180,9 +2191,11 @@ void ForestContainerCpp::AppendFromJson(JsonCpp& json, std::string forest_label) void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) { // Determine whether or not we are adding forest_num to the residuals std::function op; - if (add) op = std::plus(); - else op = std::minus(); - + if (add) + op = std::plus(); + else + op = std::minus(); + // Perform the update (addition / subtraction) operation StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op); } @@ -2190,9 +2203,11 @@ void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& void ForestCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, bool add) { // Determine whether or not we are adding forest predictions to the residuals std::function op; - if (add) op = std::plus(); - else op = std::minus(); - + if (add) + op = std::plus(); + else + op = std::minus(); + // Perform the update (addition / subtraction) operation StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_.get(), requires_basis, op); } @@ -2215,13 +2230,13 @@ void RandomEffectsContainerCpp::AddSample(RandomEffectsModelCpp& rfx_model) { py::array_t RandomEffectsContainerCpp::Predict(RandomEffectsDatasetCpp& rfx_dataset, 
RandomEffectsLabelMapperCpp& label_mapper) { py::ssize_t num_observations = rfx_dataset.NumObservations(); int num_samples = rfx_container_->NumSamples(); - std::vector output(num_observations*num_samples); + std::vector output(num_observations * num_samples); rfx_container_->Predict(*rfx_dataset.GetDataset(), *label_mapper.GetLabelMapper(), output); auto result = py::array_t(py::detail::any_container({num_observations, num_samples})); auto accessor = result.mutable_unchecked<2>(); for (size_t i = 0; i < num_observations; i++) { for (int j = 0; j < num_samples; j++) { - accessor(i, j) = output.at(j*num_observations + i); + accessor(i, j) = output.at(j * num_observations + i); } } return result; @@ -2233,10 +2248,10 @@ void RandomEffectsLabelMapperCpp::LoadFromJson(JsonCpp& json, std::string rfx_la rfx_label_mapper_->from_json(rfx_json); } -void RandomEffectsModelCpp::SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual, - RandomEffectsTrackerCpp& rfx_tracker, RandomEffectsContainerCpp& rfx_container, +void RandomEffectsModelCpp::SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual, + RandomEffectsTrackerCpp& rfx_tracker, RandomEffectsContainerCpp& rfx_container, bool keep_sample, double global_variance, RngCpp& rng) { - rfx_model_->SampleRandomEffects(*rfx_dataset.GetDataset(), *residual.GetData(), + rfx_model_->SampleRandomEffects(*rfx_dataset.GetDataset(), *residual.GetData(), *rfx_tracker.GetTracker(), global_variance, *rng.GetRng()); if (keep_sample) rfx_container.AddSample(*this); } @@ -2254,191 +2269,191 @@ PYBIND11_MODULE(stochtree_cpp, m) { m.def("cppComputeForestMaxLeafIndex", &cppComputeForestMaxLeafIndex, "Compute max leaf index of a forest in a forest container"); py::class_(m, "JsonCpp") - .def(py::init<>()) - .def("LoadFile", &JsonCpp::LoadFile) - .def("SaveFile", &JsonCpp::SaveFile) - .def("LoadFromString", &JsonCpp::LoadFromString) - .def("DumpJson", &JsonCpp::DumpJson) - .def("AddDouble", 
&JsonCpp::AddDouble) - .def("AddDoubleSubfolder", &JsonCpp::AddDoubleSubfolder) - .def("AddInteger", &JsonCpp::AddInteger) - .def("AddIntegerSubfolder", &JsonCpp::AddIntegerSubfolder) - .def("AddBool", &JsonCpp::AddBool) - .def("AddBoolSubfolder", &JsonCpp::AddBoolSubfolder) - .def("AddString", &JsonCpp::AddString) - .def("AddStringSubfolder", &JsonCpp::AddStringSubfolder) - .def("AddDoubleVector", &JsonCpp::AddDoubleVector) - .def("AddDoubleVectorSubfolder", &JsonCpp::AddDoubleVectorSubfolder) - .def("AddIntegerVector", &JsonCpp::AddIntegerVector) - .def("AddIntegerVectorSubfolder", &JsonCpp::AddIntegerVectorSubfolder) - .def("AddStringVector", &JsonCpp::AddStringVector) - .def("AddStringVectorSubfolder", &JsonCpp::AddStringVectorSubfolder) - .def("AddForest", &JsonCpp::AddForest) - .def("AddRandomEffectsContainer", &JsonCpp::AddRandomEffectsContainer) - .def("AddRandomEffectsLabelMapper", &JsonCpp::AddRandomEffectsLabelMapper) - .def("AddRandomEffectsGroupIDs", &JsonCpp::AddRandomEffectsGroupIDs) - .def("ContainsField", &JsonCpp::ContainsField) - .def("ContainsFieldSubfolder", &JsonCpp::ContainsFieldSubfolder) - .def("ExtractDouble", &JsonCpp::ExtractDouble) - .def("ExtractDoubleSubfolder", &JsonCpp::ExtractDoubleSubfolder) - .def("ExtractInteger", &JsonCpp::ExtractInteger) - .def("ExtractIntegerSubfolder", &JsonCpp::ExtractIntegerSubfolder) - .def("ExtractBool", &JsonCpp::ExtractBool) - .def("ExtractBoolSubfolder", &JsonCpp::ExtractBoolSubfolder) - .def("ExtractString", &JsonCpp::ExtractString) - .def("ExtractStringSubfolder", &JsonCpp::ExtractStringSubfolder) - .def("ExtractDoubleVector", &JsonCpp::ExtractDoubleVector) - .def("ExtractDoubleVectorSubfolder", &JsonCpp::ExtractDoubleVectorSubfolder) - .def("ExtractIntegerVector", &JsonCpp::ExtractIntegerVector) - .def("ExtractIntegerVectorSubfolder", &JsonCpp::ExtractIntegerVectorSubfolder) - .def("ExtractStringVector", &JsonCpp::ExtractStringVector) - .def("ExtractStringVectorSubfolder", 
&JsonCpp::ExtractStringVectorSubfolder) - .def("IncrementRandomEffectsCount", &JsonCpp::IncrementRandomEffectsCount) - .def("SubsetJsonForest", &JsonCpp::SubsetJsonForest) - .def("SubsetJsonRFX", &JsonCpp::SubsetJsonRFX); + .def(py::init<>()) + .def("LoadFile", &JsonCpp::LoadFile) + .def("SaveFile", &JsonCpp::SaveFile) + .def("LoadFromString", &JsonCpp::LoadFromString) + .def("DumpJson", &JsonCpp::DumpJson) + .def("AddDouble", &JsonCpp::AddDouble) + .def("AddDoubleSubfolder", &JsonCpp::AddDoubleSubfolder) + .def("AddInteger", &JsonCpp::AddInteger) + .def("AddIntegerSubfolder", &JsonCpp::AddIntegerSubfolder) + .def("AddBool", &JsonCpp::AddBool) + .def("AddBoolSubfolder", &JsonCpp::AddBoolSubfolder) + .def("AddString", &JsonCpp::AddString) + .def("AddStringSubfolder", &JsonCpp::AddStringSubfolder) + .def("AddDoubleVector", &JsonCpp::AddDoubleVector) + .def("AddDoubleVectorSubfolder", &JsonCpp::AddDoubleVectorSubfolder) + .def("AddIntegerVector", &JsonCpp::AddIntegerVector) + .def("AddIntegerVectorSubfolder", &JsonCpp::AddIntegerVectorSubfolder) + .def("AddStringVector", &JsonCpp::AddStringVector) + .def("AddStringVectorSubfolder", &JsonCpp::AddStringVectorSubfolder) + .def("AddForest", &JsonCpp::AddForest) + .def("AddRandomEffectsContainer", &JsonCpp::AddRandomEffectsContainer) + .def("AddRandomEffectsLabelMapper", &JsonCpp::AddRandomEffectsLabelMapper) + .def("AddRandomEffectsGroupIDs", &JsonCpp::AddRandomEffectsGroupIDs) + .def("ContainsField", &JsonCpp::ContainsField) + .def("ContainsFieldSubfolder", &JsonCpp::ContainsFieldSubfolder) + .def("ExtractDouble", &JsonCpp::ExtractDouble) + .def("ExtractDoubleSubfolder", &JsonCpp::ExtractDoubleSubfolder) + .def("ExtractInteger", &JsonCpp::ExtractInteger) + .def("ExtractIntegerSubfolder", &JsonCpp::ExtractIntegerSubfolder) + .def("ExtractBool", &JsonCpp::ExtractBool) + .def("ExtractBoolSubfolder", &JsonCpp::ExtractBoolSubfolder) + .def("ExtractString", &JsonCpp::ExtractString) + .def("ExtractStringSubfolder", 
&JsonCpp::ExtractStringSubfolder) + .def("ExtractDoubleVector", &JsonCpp::ExtractDoubleVector) + .def("ExtractDoubleVectorSubfolder", &JsonCpp::ExtractDoubleVectorSubfolder) + .def("ExtractIntegerVector", &JsonCpp::ExtractIntegerVector) + .def("ExtractIntegerVectorSubfolder", &JsonCpp::ExtractIntegerVectorSubfolder) + .def("ExtractStringVector", &JsonCpp::ExtractStringVector) + .def("ExtractStringVectorSubfolder", &JsonCpp::ExtractStringVectorSubfolder) + .def("IncrementRandomEffectsCount", &JsonCpp::IncrementRandomEffectsCount) + .def("SubsetJsonForest", &JsonCpp::SubsetJsonForest) + .def("SubsetJsonRFX", &JsonCpp::SubsetJsonRFX); py::class_(m, "ForestDatasetCpp") - .def(py::init<>()) - .def("AddCovariates", &ForestDatasetCpp::AddCovariates) - .def("AddBasis", &ForestDatasetCpp::AddBasis) - .def("UpdateBasis", &ForestDatasetCpp::UpdateBasis) - .def("AddVarianceWeights", &ForestDatasetCpp::AddVarianceWeights) - .def("UpdateVarianceWeights", &ForestDatasetCpp::UpdateVarianceWeights) - .def("NumRows", &ForestDatasetCpp::NumRows) - .def("NumCovariates", &ForestDatasetCpp::NumCovariates) - .def("NumBasis", &ForestDatasetCpp::NumBasis) - .def("GetCovariates", &ForestDatasetCpp::GetCovariates) - .def("GetBasis", &ForestDatasetCpp::GetBasis) - .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) - .def("HasBasis", &ForestDatasetCpp::HasBasis) - .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights) - .def("AddAuxiliaryDimension", &ForestDatasetCpp::AddAuxiliaryDimension) - .def("SetAuxiliaryDataValue", &ForestDatasetCpp::SetAuxiliaryDataValue) - .def("GetAuxiliaryDataValue", &ForestDatasetCpp::GetAuxiliaryDataValue) - .def("GetAuxiliaryDataVector", &ForestDatasetCpp::GetAuxiliaryDataVector); + .def(py::init<>()) + .def("AddCovariates", &ForestDatasetCpp::AddCovariates) + .def("AddBasis", &ForestDatasetCpp::AddBasis) + .def("UpdateBasis", &ForestDatasetCpp::UpdateBasis) + .def("AddVarianceWeights", &ForestDatasetCpp::AddVarianceWeights) + 
.def("UpdateVarianceWeights", &ForestDatasetCpp::UpdateVarianceWeights) + .def("NumRows", &ForestDatasetCpp::NumRows) + .def("NumCovariates", &ForestDatasetCpp::NumCovariates) + .def("NumBasis", &ForestDatasetCpp::NumBasis) + .def("GetCovariates", &ForestDatasetCpp::GetCovariates) + .def("GetBasis", &ForestDatasetCpp::GetBasis) + .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) + .def("HasBasis", &ForestDatasetCpp::HasBasis) + .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights) + .def("AddAuxiliaryDimension", &ForestDatasetCpp::AddAuxiliaryDimension) + .def("SetAuxiliaryDataValue", &ForestDatasetCpp::SetAuxiliaryDataValue) + .def("GetAuxiliaryDataValue", &ForestDatasetCpp::GetAuxiliaryDataValue) + .def("GetAuxiliaryDataVector", &ForestDatasetCpp::GetAuxiliaryDataVector); py::class_(m, "ResidualCpp") - .def(py::init,data_size_t>()) - .def("GetResidualArray", &ResidualCpp::GetResidualArray) - .def("ReplaceData", &ResidualCpp::ReplaceData) - .def("AddToData", &ResidualCpp::AddToData) - .def("SubtractFromData", &ResidualCpp::SubtractFromData); + .def(py::init, data_size_t>()) + .def("GetResidualArray", &ResidualCpp::GetResidualArray) + .def("ReplaceData", &ResidualCpp::ReplaceData) + .def("AddToData", &ResidualCpp::AddToData) + .def("SubtractFromData", &ResidualCpp::SubtractFromData); py::class_(m, "RngCpp") - .def(py::init()); - + .def(py::init()); + py::class_(m, "ForestContainerCpp") - .def(py::init()) - .def("CombineForests", &ForestContainerCpp::CombineForests) - .def("AddToForest", &ForestContainerCpp::AddToForest) - .def("MultiplyForest", &ForestContainerCpp::MultiplyForest) - .def("OutputDimension", &ForestContainerCpp::OutputDimension) - .def("NumTrees", &ForestContainerCpp::NumTrees) - .def("NumSamples", &ForestContainerCpp::NumSamples) - .def("DeleteSample", &ForestContainerCpp::DeleteSample) - .def("Predict", &ForestContainerCpp::Predict) - .def("PredictRaw", &ForestContainerCpp::PredictRaw) - .def("PredictRawSingleForest", 
&ForestContainerCpp::PredictRawSingleForest) - .def("SetRootValue", &ForestContainerCpp::SetRootValue) - .def("SetRootVector", &ForestContainerCpp::SetRootVector) - .def("AdjustResidual", &ForestContainerCpp::AdjustResidual) - .def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile) - .def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile) - .def("LoadFromJson", &ForestContainerCpp::LoadFromJson) - .def("AppendFromJson", &ForestContainerCpp::AppendFromJson) - .def("DumpJsonString", &ForestContainerCpp::DumpJsonString) - .def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString) - .def("AddSampleValue", &ForestContainerCpp::AddSampleValue) - .def("AddSampleVector", &ForestContainerCpp::AddSampleVector) - .def("AddNumericSplitValue", &ForestContainerCpp::AddNumericSplitValue) - .def("AddNumericSplitVector", &ForestContainerCpp::AddNumericSplitVector) - .def("GetTreeLeaves", &ForestContainerCpp::GetTreeLeaves) - .def("GetTreeSplitCounts", &ForestContainerCpp::GetTreeSplitCounts) - .def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts) - .def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts) - .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts) - .def("NumLeavesForest", &ForestContainerCpp::NumLeavesForest) - .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared) - .def("IsLeafNode", &ForestContainerCpp::IsLeafNode) - .def("IsNumericSplitNode", &ForestContainerCpp::IsNumericSplitNode) - .def("IsCategoricalSplitNode", &ForestContainerCpp::IsCategoricalSplitNode) - .def("ParentNode", &ForestContainerCpp::ParentNode) - .def("LeftChildNode", &ForestContainerCpp::LeftChildNode) - .def("RightChildNode", &ForestContainerCpp::RightChildNode) - .def("SplitIndex", &ForestContainerCpp::SplitIndex) - .def("NodeDepth", &ForestContainerCpp::NodeDepth) - .def("SplitThreshold", &ForestContainerCpp::SplitThreshold) - .def("SplitCategories", &ForestContainerCpp::SplitCategories) - .def("NodeLeafValues", 
&ForestContainerCpp::NodeLeafValues) - .def("NumNodes", &ForestContainerCpp::NumNodes) - .def("NumLeaves", &ForestContainerCpp::NumLeaves) - .def("NumLeafParents", &ForestContainerCpp::NumLeafParents) - .def("NumSplitNodes", &ForestContainerCpp::NumSplitNodes) - .def("Nodes", &ForestContainerCpp::Nodes) - .def("Leaves", &ForestContainerCpp::Leaves); + .def(py::init()) + .def("CombineForests", &ForestContainerCpp::CombineForests) + .def("AddToForest", &ForestContainerCpp::AddToForest) + .def("MultiplyForest", &ForestContainerCpp::MultiplyForest) + .def("OutputDimension", &ForestContainerCpp::OutputDimension) + .def("NumTrees", &ForestContainerCpp::NumTrees) + .def("NumSamples", &ForestContainerCpp::NumSamples) + .def("DeleteSample", &ForestContainerCpp::DeleteSample) + .def("Predict", &ForestContainerCpp::Predict) + .def("PredictRaw", &ForestContainerCpp::PredictRaw) + .def("PredictRawSingleForest", &ForestContainerCpp::PredictRawSingleForest) + .def("SetRootValue", &ForestContainerCpp::SetRootValue) + .def("SetRootVector", &ForestContainerCpp::SetRootVector) + .def("AdjustResidual", &ForestContainerCpp::AdjustResidual) + .def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile) + .def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile) + .def("LoadFromJson", &ForestContainerCpp::LoadFromJson) + .def("AppendFromJson", &ForestContainerCpp::AppendFromJson) + .def("DumpJsonString", &ForestContainerCpp::DumpJsonString) + .def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString) + .def("AddSampleValue", &ForestContainerCpp::AddSampleValue) + .def("AddSampleVector", &ForestContainerCpp::AddSampleVector) + .def("AddNumericSplitValue", &ForestContainerCpp::AddNumericSplitValue) + .def("AddNumericSplitVector", &ForestContainerCpp::AddNumericSplitVector) + .def("GetTreeLeaves", &ForestContainerCpp::GetTreeLeaves) + .def("GetTreeSplitCounts", &ForestContainerCpp::GetTreeSplitCounts) + .def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts) + 
.def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts) + .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts) + .def("NumLeavesForest", &ForestContainerCpp::NumLeavesForest) + .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared) + .def("IsLeafNode", &ForestContainerCpp::IsLeafNode) + .def("IsNumericSplitNode", &ForestContainerCpp::IsNumericSplitNode) + .def("IsCategoricalSplitNode", &ForestContainerCpp::IsCategoricalSplitNode) + .def("ParentNode", &ForestContainerCpp::ParentNode) + .def("LeftChildNode", &ForestContainerCpp::LeftChildNode) + .def("RightChildNode", &ForestContainerCpp::RightChildNode) + .def("SplitIndex", &ForestContainerCpp::SplitIndex) + .def("NodeDepth", &ForestContainerCpp::NodeDepth) + .def("SplitThreshold", &ForestContainerCpp::SplitThreshold) + .def("SplitCategories", &ForestContainerCpp::SplitCategories) + .def("NodeLeafValues", &ForestContainerCpp::NodeLeafValues) + .def("NumNodes", &ForestContainerCpp::NumNodes) + .def("NumLeaves", &ForestContainerCpp::NumLeaves) + .def("NumLeafParents", &ForestContainerCpp::NumLeafParents) + .def("NumSplitNodes", &ForestContainerCpp::NumSplitNodes) + .def("Nodes", &ForestContainerCpp::Nodes) + .def("Leaves", &ForestContainerCpp::Leaves); py::class_(m, "ForestCpp") - .def(py::init()) - .def("GetForestPtr", &ForestCpp::GetForestPtr) - .def("MergeForest", &ForestCpp::MergeForest) - .def("AddConstant", &ForestCpp::AddConstant) - .def("MultiplyConstant", &ForestCpp::MultiplyConstant) - .def("OutputDimension", &ForestCpp::OutputDimension) - .def("NumTrees", &ForestCpp::NumTrees) - .def("NumLeavesForest", &ForestCpp::NumLeavesForest) - .def("SumLeafSquared", &ForestCpp::SumLeafSquared) - .def("ResetRoot", &ForestCpp::ResetRoot) - .def("Reset", &ForestCpp::Reset) - .def("Predict", &ForestCpp::Predict) - .def("PredictRaw", &ForestCpp::PredictRaw) - .def("SetRootValue", &ForestCpp::SetRootValue) - .def("SetRootVector", &ForestCpp::SetRootVector) - 
.def("AdjustResidual", &ForestCpp::AdjustResidual) - .def("AddNumericSplitValue", &ForestCpp::AddNumericSplitValue) - .def("AddNumericSplitVector", &ForestCpp::AddNumericSplitVector) - .def("GetEnsemble", &ForestCpp::GetEnsemble) - .def("GetTreeLeaves", &ForestCpp::GetTreeLeaves) - .def("GetTreeSplitCounts", &ForestCpp::GetTreeSplitCounts) - .def("GetOverallSplitCounts", &ForestCpp::GetOverallSplitCounts) - .def("GetGranularSplitCounts", &ForestCpp::GetGranularSplitCounts) - .def("NumLeavesForest", &ForestCpp::NumLeavesForest) - .def("SumLeafSquared", &ForestCpp::SumLeafSquared) - .def("IsLeafNode", &ForestCpp::IsLeafNode) - .def("IsNumericSplitNode", &ForestCpp::IsNumericSplitNode) - .def("IsCategoricalSplitNode", &ForestCpp::IsCategoricalSplitNode) - .def("ParentNode", &ForestCpp::ParentNode) - .def("LeftChildNode", &ForestCpp::LeftChildNode) - .def("RightChildNode", &ForestCpp::RightChildNode) - .def("SplitIndex", &ForestCpp::SplitIndex) - .def("NodeDepth", &ForestCpp::NodeDepth) - .def("SplitThreshold", &ForestCpp::SplitThreshold) - .def("SplitCategories", &ForestCpp::SplitCategories) - .def("NodeLeafValues", &ForestCpp::NodeLeafValues) - .def("NumNodes", &ForestCpp::NumNodes) - .def("NumLeaves", &ForestCpp::NumLeaves) - .def("NumLeafParents", &ForestCpp::NumLeafParents) - .def("NumSplitNodes", &ForestCpp::NumSplitNodes) - .def("Nodes", &ForestCpp::Nodes) - .def("Leaves", &ForestCpp::Leaves); - + .def(py::init()) + .def("GetForestPtr", &ForestCpp::GetForestPtr) + .def("MergeForest", &ForestCpp::MergeForest) + .def("AddConstant", &ForestCpp::AddConstant) + .def("MultiplyConstant", &ForestCpp::MultiplyConstant) + .def("OutputDimension", &ForestCpp::OutputDimension) + .def("NumTrees", &ForestCpp::NumTrees) + .def("NumLeavesForest", &ForestCpp::NumLeavesForest) + .def("SumLeafSquared", &ForestCpp::SumLeafSquared) + .def("ResetRoot", &ForestCpp::ResetRoot) + .def("Reset", &ForestCpp::Reset) + .def("Predict", &ForestCpp::Predict) + .def("PredictRaw", 
&ForestCpp::PredictRaw) + .def("SetRootValue", &ForestCpp::SetRootValue) + .def("SetRootVector", &ForestCpp::SetRootVector) + .def("AdjustResidual", &ForestCpp::AdjustResidual) + .def("AddNumericSplitValue", &ForestCpp::AddNumericSplitValue) + .def("AddNumericSplitVector", &ForestCpp::AddNumericSplitVector) + .def("GetEnsemble", &ForestCpp::GetEnsemble) + .def("GetTreeLeaves", &ForestCpp::GetTreeLeaves) + .def("GetTreeSplitCounts", &ForestCpp::GetTreeSplitCounts) + .def("GetOverallSplitCounts", &ForestCpp::GetOverallSplitCounts) + .def("GetGranularSplitCounts", &ForestCpp::GetGranularSplitCounts) + .def("NumLeavesForest", &ForestCpp::NumLeavesForest) + .def("SumLeafSquared", &ForestCpp::SumLeafSquared) + .def("IsLeafNode", &ForestCpp::IsLeafNode) + .def("IsNumericSplitNode", &ForestCpp::IsNumericSplitNode) + .def("IsCategoricalSplitNode", &ForestCpp::IsCategoricalSplitNode) + .def("ParentNode", &ForestCpp::ParentNode) + .def("LeftChildNode", &ForestCpp::LeftChildNode) + .def("RightChildNode", &ForestCpp::RightChildNode) + .def("SplitIndex", &ForestCpp::SplitIndex) + .def("NodeDepth", &ForestCpp::NodeDepth) + .def("SplitThreshold", &ForestCpp::SplitThreshold) + .def("SplitCategories", &ForestCpp::SplitCategories) + .def("NodeLeafValues", &ForestCpp::NodeLeafValues) + .def("NumNodes", &ForestCpp::NumNodes) + .def("NumLeaves", &ForestCpp::NumLeaves) + .def("NumLeafParents", &ForestCpp::NumLeafParents) + .def("NumSplitNodes", &ForestCpp::NumSplitNodes) + .def("Nodes", &ForestCpp::Nodes) + .def("Leaves", &ForestCpp::Leaves); + py::class_(m, "ForestSamplerCpp") - .def(py::init, int, data_size_t, double, double, int, int>()) - .def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest) - .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration) - .def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel) - .def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions) - .def("PropagateBasisUpdate", 
&ForestSamplerCpp::PropagateBasisUpdate) - .def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate) - .def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha) - .def("UpdateBeta", &ForestSamplerCpp::UpdateBeta) - .def("UpdateMinSamplesLeaf", &ForestSamplerCpp::UpdateMinSamplesLeaf) - .def("UpdateMaxDepth", &ForestSamplerCpp::UpdateMaxDepth) - .def("GetAlpha", &ForestSamplerCpp::GetAlpha) - .def("GetBeta", &ForestSamplerCpp::GetBeta) - .def("GetMinSamplesLeaf", &ForestSamplerCpp::GetMinSamplesLeaf) - .def("GetMaxDepth", &ForestSamplerCpp::GetMaxDepth); - - py::class_(m, "RandomEffectsDatasetCpp") + .def(py::init, int, data_size_t, double, double, int, int>()) + .def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest) + .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration) + .def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel) + .def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions) + .def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate) + .def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate) + .def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha) + .def("UpdateBeta", &ForestSamplerCpp::UpdateBeta) + .def("UpdateMinSamplesLeaf", &ForestSamplerCpp::UpdateMinSamplesLeaf) + .def("UpdateMaxDepth", &ForestSamplerCpp::UpdateMaxDepth) + .def("GetAlpha", &ForestSamplerCpp::GetAlpha) + .def("GetBeta", &ForestSamplerCpp::GetBeta) + .def("GetMinSamplesLeaf", &ForestSamplerCpp::GetMinSamplesLeaf) + .def("GetMaxDepth", &ForestSamplerCpp::GetMaxDepth); + + py::class_(m, "RandomEffectsDatasetCpp") .def(py::init<>()) .def("GetDataset", &RandomEffectsDatasetCpp::GetDataset) .def("NumObservations", &RandomEffectsDatasetCpp::NumObservations) @@ -2457,71 +2472,71 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights); py::class_(m, "RandomEffectsContainerCpp") - .def(py::init<>()) - 
.def("SetComponentsAndGroups", &RandomEffectsContainerCpp::SetComponentsAndGroups) - .def("AddSample", &RandomEffectsContainerCpp::AddSample) - .def("NumSamples", &RandomEffectsContainerCpp::NumSamples) - .def("NumComponents", &RandomEffectsContainerCpp::NumComponents) - .def("NumGroups", &RandomEffectsContainerCpp::NumGroups) - .def("GetBeta", &RandomEffectsContainerCpp::GetBeta) - .def("GetXi", &RandomEffectsContainerCpp::GetXi) - .def("GetAlpha", &RandomEffectsContainerCpp::GetAlpha) - .def("GetSigma", &RandomEffectsContainerCpp::GetSigma) - .def("DeleteSample", &RandomEffectsContainerCpp::DeleteSample) - .def("Predict", &RandomEffectsContainerCpp::Predict) - .def("SaveToJsonFile", &RandomEffectsContainerCpp::SaveToJsonFile) - .def("LoadFromJsonFile", &RandomEffectsContainerCpp::LoadFromJsonFile) - .def("DumpJsonString", &RandomEffectsContainerCpp::DumpJsonString) - .def("LoadFromJsonString", &RandomEffectsContainerCpp::LoadFromJsonString) - .def("LoadFromJson", &RandomEffectsContainerCpp::LoadFromJson) - .def("AppendFromJson", &RandomEffectsContainerCpp::AppendFromJson) - .def("GetRandomEffectsContainer", &RandomEffectsContainerCpp::GetRandomEffectsContainer); + .def(py::init<>()) + .def("SetComponentsAndGroups", &RandomEffectsContainerCpp::SetComponentsAndGroups) + .def("AddSample", &RandomEffectsContainerCpp::AddSample) + .def("NumSamples", &RandomEffectsContainerCpp::NumSamples) + .def("NumComponents", &RandomEffectsContainerCpp::NumComponents) + .def("NumGroups", &RandomEffectsContainerCpp::NumGroups) + .def("GetBeta", &RandomEffectsContainerCpp::GetBeta) + .def("GetXi", &RandomEffectsContainerCpp::GetXi) + .def("GetAlpha", &RandomEffectsContainerCpp::GetAlpha) + .def("GetSigma", &RandomEffectsContainerCpp::GetSigma) + .def("DeleteSample", &RandomEffectsContainerCpp::DeleteSample) + .def("Predict", &RandomEffectsContainerCpp::Predict) + .def("SaveToJsonFile", &RandomEffectsContainerCpp::SaveToJsonFile) + .def("LoadFromJsonFile", 
&RandomEffectsContainerCpp::LoadFromJsonFile) + .def("DumpJsonString", &RandomEffectsContainerCpp::DumpJsonString) + .def("LoadFromJsonString", &RandomEffectsContainerCpp::LoadFromJsonString) + .def("LoadFromJson", &RandomEffectsContainerCpp::LoadFromJson) + .def("AppendFromJson", &RandomEffectsContainerCpp::AppendFromJson) + .def("GetRandomEffectsContainer", &RandomEffectsContainerCpp::GetRandomEffectsContainer); py::class_(m, "RandomEffectsTrackerCpp") - .def(py::init>()) - .def("GetUniqueGroupIds", &RandomEffectsTrackerCpp::GetUniqueGroupIds) - .def("GetTracker", &RandomEffectsTrackerCpp::GetTracker) - .def("Reset", &RandomEffectsTrackerCpp::Reset) - .def("RootReset", &RandomEffectsTrackerCpp::RootReset); + .def(py::init>()) + .def("GetUniqueGroupIds", &RandomEffectsTrackerCpp::GetUniqueGroupIds) + .def("GetTracker", &RandomEffectsTrackerCpp::GetTracker) + .def("Reset", &RandomEffectsTrackerCpp::Reset) + .def("RootReset", &RandomEffectsTrackerCpp::RootReset); py::class_(m, "RandomEffectsLabelMapperCpp") - .def(py::init<>()) - .def("LoadFromTracker", &RandomEffectsLabelMapperCpp::LoadFromTracker) - .def("SaveToJsonFile", &RandomEffectsLabelMapperCpp::SaveToJsonFile) - .def("LoadFromJsonFile", &RandomEffectsLabelMapperCpp::LoadFromJsonFile) - .def("DumpJsonString", &RandomEffectsLabelMapperCpp::DumpJsonString) - .def("LoadFromJsonString", &RandomEffectsLabelMapperCpp::LoadFromJsonString) - .def("LoadFromJson", &RandomEffectsLabelMapperCpp::LoadFromJson) - .def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper) - .def("MapGroupIdToArrayIndex", &RandomEffectsLabelMapperCpp::MapGroupIdToArrayIndex) - .def("MapMultipleGroupIdsToArrayIndices", &RandomEffectsLabelMapperCpp::MapMultipleGroupIdsToArrayIndices); + .def(py::init<>()) + .def("LoadFromTracker", &RandomEffectsLabelMapperCpp::LoadFromTracker) + .def("SaveToJsonFile", &RandomEffectsLabelMapperCpp::SaveToJsonFile) + .def("LoadFromJsonFile", &RandomEffectsLabelMapperCpp::LoadFromJsonFile) + 
.def("DumpJsonString", &RandomEffectsLabelMapperCpp::DumpJsonString) + .def("LoadFromJsonString", &RandomEffectsLabelMapperCpp::LoadFromJsonString) + .def("LoadFromJson", &RandomEffectsLabelMapperCpp::LoadFromJson) + .def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper) + .def("MapGroupIdToArrayIndex", &RandomEffectsLabelMapperCpp::MapGroupIdToArrayIndex) + .def("MapMultipleGroupIdsToArrayIndices", &RandomEffectsLabelMapperCpp::MapMultipleGroupIdsToArrayIndices); py::class_(m, "RandomEffectsModelCpp") - .def(py::init()) - .def("GetModel", &RandomEffectsModelCpp::GetModel) - .def("SampleRandomEffects", &RandomEffectsModelCpp::SampleRandomEffects) - .def("Predict", &RandomEffectsModelCpp::Predict) - .def("SetWorkingParameter", &RandomEffectsModelCpp::SetWorkingParameter) - .def("SetGroupParameters", &RandomEffectsModelCpp::SetGroupParameters) - .def("SetWorkingParameterCovariance", &RandomEffectsModelCpp::SetWorkingParameterCovariance) - .def("SetGroupParameterCovariance", &RandomEffectsModelCpp::SetGroupParameterCovariance) - .def("SetVariancePriorShape", &RandomEffectsModelCpp::SetVariancePriorShape) - .def("SetVariancePriorScale", &RandomEffectsModelCpp::SetVariancePriorScale) - .def("Reset", &RandomEffectsModelCpp::Reset); + .def(py::init()) + .def("GetModel", &RandomEffectsModelCpp::GetModel) + .def("SampleRandomEffects", &RandomEffectsModelCpp::SampleRandomEffects) + .def("Predict", &RandomEffectsModelCpp::Predict) + .def("SetWorkingParameter", &RandomEffectsModelCpp::SetWorkingParameter) + .def("SetGroupParameters", &RandomEffectsModelCpp::SetGroupParameters) + .def("SetWorkingParameterCovariance", &RandomEffectsModelCpp::SetWorkingParameterCovariance) + .def("SetGroupParameterCovariance", &RandomEffectsModelCpp::SetGroupParameterCovariance) + .def("SetVariancePriorShape", &RandomEffectsModelCpp::SetVariancePriorShape) + .def("SetVariancePriorScale", &RandomEffectsModelCpp::SetVariancePriorScale) + .def("Reset", &RandomEffectsModelCpp::Reset); 
py::class_(m, "GlobalVarianceModelCpp") - .def(py::init<>()) - .def("SampleOneIteration", &GlobalVarianceModelCpp::SampleOneIteration); + .def(py::init<>()) + .def("SampleOneIteration", &GlobalVarianceModelCpp::SampleOneIteration); py::class_(m, "LeafVarianceModelCpp") - .def(py::init<>()) - .def("SampleOneIteration", &LeafVarianceModelCpp::SampleOneIteration); + .def(py::init<>()) + .def("SampleOneIteration", &LeafVarianceModelCpp::SampleOneIteration); py::class_(m, "OrdinalSamplerCpp") - .def(py::init<>()) - .def("UpdateLatentVariables", &OrdinalSamplerCpp::UpdateLatentVariables) - .def("UpdateGammaParams", &OrdinalSamplerCpp::UpdateGammaParams) - .def("UpdateCumulativeExpSums", &OrdinalSamplerCpp::UpdateCumulativeExpSums); + .def(py::init<>()) + .def("UpdateLatentVariables", &OrdinalSamplerCpp::UpdateLatentVariables) + .def("UpdateGammaParams", &OrdinalSamplerCpp::UpdateGammaParams) + .def("UpdateCumulativeExpSums", &OrdinalSamplerCpp::UpdateCumulativeExpSums); #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); diff --git a/src/random_effects.cpp b/src/random_effects.cpp index 40c828ba..54f5b52f 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -39,7 +39,7 @@ void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) { } } -void RandomEffectsTracker::ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, +void RandomEffectsTracker::ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, RandomEffectsDataset& rfx_dataset, ColumnVector& residual) { Eigen::MatrixXd X = rfx_dataset.GetBasis(); std::vector group_labels = rfx_dataset.GetGroupLabels(); @@ -61,7 +61,7 @@ void RandomEffectsTracker::ResetFromSample(MultivariateRegressionRandomEffectsMo } } -void RandomEffectsTracker::RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, +void RandomEffectsTracker::RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, RandomEffectsDataset& rfx_dataset, ColumnVector& 
residual) { int n = rfx_dataset.NumObservations(); CHECK_EQ(n, num_observations_); @@ -78,26 +78,26 @@ void RandomEffectsTracker::RootReset(MultivariateRegressionRandomEffectsModel& r } void MultivariateRegressionRandomEffectsModel::ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num) { - // Extract parameter vectors - std::vector& alpha = rfx_container.GetAlpha(); - std::vector& xi = rfx_container.GetXi(); - std::vector& sigma = rfx_container.GetSigma(); - - // Unpack parameters - for (int i = 0; i < num_components_; i++) { - working_parameter_(i) = alpha.at(sample_num*num_components_ + i); - group_parameter_covariance_(i, i) = sigma.at(sample_num*num_components_ + i); - for (int j = 0; j < num_groups_; j++) { - group_parameters_(i,j) = xi.at(sample_num*num_groups_*num_components_ + j*num_components_ + i); - } + // Extract parameter vectors + std::vector& alpha = rfx_container.GetAlpha(); + std::vector& xi = rfx_container.GetXi(); + std::vector& sigma = rfx_container.GetSigma(); + + // Unpack parameters + for (int i = 0; i < num_components_; i++) { + working_parameter_(i) = alpha.at(sample_num * num_components_ + i); + group_parameter_covariance_(i, i) = sigma.at(sample_num * num_components_ + i); + for (int j = 0; j < num_groups_; j++) { + group_parameters_(i, j) = xi.at(sample_num * num_groups_ * num_components_ + j * num_components_ + i); } } +} -void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, +void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { // Update partial residual to add back in the random effects AddCurrentPredictionToResidual(dataset, rfx_tracker, residual); - + // Sample random effects SampleGroupParameters(dataset, residual, rfx_tracker, global_variance, gen); 
SampleWorkingParameter(dataset, residual, rfx_tracker, global_variance, gen); @@ -107,14 +107,14 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects SubtractNewPredictionFromResidual(dataset, rfx_tracker, residual); } -void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, +void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { Eigen::VectorXd posterior_mean = WorkingParameterMean(dataset, residual, rfx_tracker, global_variance); Eigen::MatrixXd posterior_covariance = WorkingParameterVariance(dataset, residual, rfx_tracker, global_variance); working_parameter_ = normal_sampler_.SampleEigen(posterior_mean, posterior_covariance, gen); } -void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, +void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { int32_t num_groups = num_groups_; Eigen::VectorXd posterior_mean; @@ -124,10 +124,10 @@ void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffec posterior_mean = GroupParameterMean(dataset, residual, rfx_tracker, global_variance, i); posterior_covariance = GroupParameterVariance(dataset, residual, rfx_tracker, global_variance, i); group_parameters_(Eigen::all, i) = normal_sampler_.SampleEigen(posterior_mean, posterior_covariance, gen); - } + } } -void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, +void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double 
global_variance, std::mt19937& gen) { int32_t num_components = num_components_; double posterior_shape; @@ -140,8 +140,8 @@ void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEf } } -Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, - double global_variance){ +Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, + double global_variance) { int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector observation_indices; @@ -164,7 +164,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R return posterior_denominator.inverse() * posterior_numerator; } -Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){ +Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance) { int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector observation_indices; @@ -202,19 +202,19 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran return posterior_denominator.inverse() * posterior_numerator; } -Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id){ +Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id) { int32_t 
num_components = num_components_; int32_t num_groups = num_groups_; Eigen::MatrixXd X = dataset.GetBasis(); Eigen::VectorXd y = residual.GetData(); Eigen::VectorXd alpha = working_parameter_; Eigen::MatrixXd posterior_denominator = group_parameter_covariance_.inverse(); -// Eigen::VectorXd posterior_numerator = Eigen::VectorXd::Zero(num_components); + // Eigen::VectorXd posterior_numerator = Eigen::VectorXd::Zero(num_components); std::vector observation_indices = rfx_tracker.NodeIndicesInternalIndex(group_id); Eigen::MatrixXd X_group = X(observation_indices, Eigen::all); -// Eigen::VectorXd y_group = y(observation_indices, Eigen::all); + // Eigen::VectorXd y_group = y(observation_indices, Eigen::all); posterior_denominator += ((alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal()) / (global_variance); -// posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group; + // posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group; return posterior_denominator.inverse(); } @@ -227,36 +227,36 @@ double MultivariateRegressionRandomEffectsModel::VarianceComponentScale(RandomEf Eigen::MatrixXd xi = group_parameters_; double output = variance_prior_scale_; for (int i = 0; i < num_groups; i++) { - output += xi(component_id, i)*xi(component_id, i); + output += xi(component_id, i) * xi(component_id, i); } return output; } -void RandomEffectsContainer::AddSample(MultivariateRegressionRandomEffectsModel& model){ +void RandomEffectsContainer::AddSample(MultivariateRegressionRandomEffectsModel& model) { // Increment number of samples int sample_ind = num_samples_; num_samples_++; // Add alpha - alpha_.resize(num_samples_*num_components_); + alpha_.resize(num_samples_ * num_components_); for (int i = 0; i < num_components_; i++) { - alpha_.at(sample_ind*num_components_ + i) = model.GetWorkingParameter()(i); + alpha_.at(sample_ind * num_components_ + i) = model.GetWorkingParameter()(i); } // Add xi and beta - 
xi_.resize(num_samples_*num_components_*num_groups_); - beta_.resize(num_samples_*num_components_*num_groups_); + xi_.resize(num_samples_ * num_components_ * num_groups_); + beta_.resize(num_samples_ * num_components_ * num_groups_); for (int i = 0; i < num_components_; i++) { for (int j = 0; j < num_groups_; j++) { - xi_.at(sample_ind*num_groups_*num_components_ + j*num_components_ + i) = model.GetGroupParameters()(i,j); - beta_.at(sample_ind*num_groups_*num_components_ + j*num_components_ + i) = xi_.at(sample_ind*num_groups_*num_components_ + j*num_components_ + i) * alpha_.at(sample_ind*num_components_ + i); + xi_.at(sample_ind * num_groups_ * num_components_ + j * num_components_ + i) = model.GetGroupParameters()(i, j); + beta_.at(sample_ind * num_groups_ * num_components_ + j * num_components_ + i) = xi_.at(sample_ind * num_groups_ * num_components_ + j * num_components_ + i) * alpha_.at(sample_ind * num_components_ + i); } } // Add sigma - sigma_xi_.resize(num_samples_*num_components_); + sigma_xi_.resize(num_samples_ * num_components_); for (int i = 0; i < num_components_; i++) { - sigma_xi_.at(sample_ind*num_components_ + i) = model.GetGroupParameterCovariance()(i,i); + sigma_xi_.at(sample_ind * num_components_ + i) = model.GetGroupParameterCovariance()(i, i); } } @@ -265,7 +265,7 @@ void RandomEffectsContainer::Predict(RandomEffectsDataset& dataset, LabelMapper& std::vector group_labels = dataset.GetGroupLabels(); CHECK_EQ(X.rows(), group_labels.size()); int n = X.rows(); - CHECK_EQ(n*num_samples_, output.size()); + CHECK_EQ(n * num_samples_, output.size()); std::int32_t group_ind; double pred; for (int i = 0; i < n; i++) { @@ -273,9 +273,9 @@ void RandomEffectsContainer::Predict(RandomEffectsDataset& dataset, LabelMapper& for (int j = 0; j < num_samples_; j++) { pred = 0; for (int k = 0; k < num_components_; k++) { - pred += X(i,k) * beta_.at(j*num_groups_*num_components_ + group_ind*num_components_ + k); + pred += X(i, k) * beta_.at(j * num_groups_ * 
num_components_ + group_ind * num_components_ + k); } - output.at(j*n + i) = pred; + output.at(j * n + i) = pred; } } } @@ -288,8 +288,8 @@ nlohmann::json RandomEffectsContainer::to_json() { result_obj.emplace("num_groups", num_groups_); // Store some meta-level information about the containers - int beta_size = num_groups_*num_components_*num_samples_; - int alpha_size = num_components_*num_samples_; + int beta_size = num_groups_ * num_components_ * num_samples_; + int alpha_size = num_components_ * num_samples_; result_obj.emplace("beta_size", beta_size); result_obj.emplace("alpha_size", alpha_size); @@ -317,39 +317,39 @@ nlohmann::json RandomEffectsContainer::to_json() { result_obj.emplace(pair); } -return result_obj; + return result_obj; } -void RandomEffectsContainer::DeleteSample(int sample_num){ +void RandomEffectsContainer::DeleteSample(int sample_num) { // Decrement number of samples num_samples_--; // Remove sample_num from alpha // ---------------------------- - // This code works because the data are stored in a "column-major" format, - // with components comprising rows and and samples comprising columns, so that - // element `sample_num*num_components_ + i` will contain the "i"-th component of the - // sample indexed by sample_num. Erasing the `sample_num*num_components_ + 0` - // element of the vector will move the element that was previously in position + // This code works because the data are stored in a "column-major" format, + // with components comprising rows and and samples comprising columns, so that + // element `sample_num*num_components_ + i` will contain the "i"-th component of the + // sample indexed by sample_num. 
Erasing the `sample_num*num_components_ + 0` + // element of the vector will move the element that was previously in position // `sample_num*num_components_ + 1` into the position `sample_num*num_components_ + 0` // and thus we can repeat `alpha_.erase(alpha_.begin() + sample_num*num_components_);` // exactly `num_components_` times to erase each component pertaining to this sample. for (int i = 0; i < num_components_; i++) { - alpha_.erase(alpha_.begin() + sample_num*num_components_); + alpha_.erase(alpha_.begin() + sample_num * num_components_); } // Remove sample_num from xi and beta // ---------------------------------- - // This code works as above, with the added nuance of the three-dimensional (Fortran-aligned) array, - // in which sample number is the third dimension, group number is the second dimension, and component - // number is the third dimension. The nested loop assembles all `num_groups_*num_components_` offsets, - // expressed as `j*num_components_ + i`. In order to remove each of the elements stored in these offsets - // from `sample_num*num_groups_*num_components_`, we simply need to erase the + // This code works as above, with the added nuance of the three-dimensional (Fortran-aligned) array, + // in which sample number is the third dimension, group number is the second dimension, and component + // number is the third dimension. The nested loop assembles all `num_groups_*num_components_` offsets, + // expressed as `j*num_components_ + i`. In order to remove each of the elements stored in these offsets + // from `sample_num*num_groups_*num_components_`, we simply need to erase the // `sample_num*num_groups_*num_components_` element, exactly `num_groups_*num_components_` times. 
for (int i = 0; i < num_components_; i++) { for (int j = 0; j < num_groups_; j++) { - xi_.erase(xi_.begin() + sample_num*num_groups_*num_components_); - beta_.erase(beta_.begin() + sample_num*num_groups_*num_components_); + xi_.erase(xi_.begin() + sample_num * num_groups_ * num_components_); + beta_.erase(beta_.begin() + sample_num * num_groups_ * num_components_); } } @@ -357,7 +357,7 @@ void RandomEffectsContainer::DeleteSample(int sample_num){ // ---------------------------- // This code works as with alpha for (int i = 0; i < num_components_; i++) { - sigma_xi_.erase(sigma_xi_.begin() + sample_num*num_components_); + sigma_xi_.erase(sigma_xi_.begin() + sample_num * num_components_); } } @@ -375,13 +375,13 @@ void RandomEffectsContainer::from_json(const nlohmann::json& rfx_container_json) this->num_samples_ = rfx_container_json.at("num_samples"); this->num_components_ = rfx_container_json.at("num_components"); this->num_groups_ = rfx_container_json.at("num_groups"); - + // Unpack beta and xi for (int i = 0; i < beta_size; i++) { beta_.push_back(rfx_container_json.at("beta").at(i)); xi_.push_back(rfx_container_json.at("xi").at(i)); } - + // Unpack alpha and sigma_xi for (int i = 0; i < alpha_size; i++) { alpha_.push_back(rfx_container_json.at("alpha").at(i)); @@ -392,19 +392,19 @@ void RandomEffectsContainer::from_json(const nlohmann::json& rfx_container_json) void RandomEffectsContainer::append_from_json(const nlohmann::json& rfx_container_json) { CHECK_EQ(this->num_components_, rfx_container_json.at("num_components")); CHECK_EQ(this->num_groups_, rfx_container_json.at("num_groups")); - + // Update internal sample count and extract size of parameter vectors int new_num_samples = rfx_container_json.at("num_samples"); this->num_samples_ += new_num_samples; int beta_size = rfx_container_json.at("beta_size"); int alpha_size = rfx_container_json.at("alpha_size"); - + // Unpack beta and xi for (int i = 0; i < beta_size; i++) { 
beta_.push_back(rfx_container_json.at("beta").at(i)); xi_.push_back(rfx_container_json.at("xi").at(i)); } - + // Unpack alpha and sigma_xi for (int i = 0; i < alpha_size; i++) { alpha_.push_back(rfx_container_json.at("alpha").at(i)); diff --git a/src/sampler.cpp b/src/sampler.cpp index f356d968..84df1cd6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -24,76 +24,81 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); - for (int i = 0; i < feature_types.size(); i++) { - feature_types_[i] = static_cast(feature_types[i]); - } - - // Unpack sweep indices - std::vector sweep_indices_(sweep_indices.size()); - // if (sweep_indices.size() > 0) { - // sweep_indices_.resize(sweep_indices.size()); - for (int i = 0; i < sweep_indices.size(); i++) { - sweep_indices_[i] = sweep_indices[i]; - } - // } - - // Convert leaf model type to enum - StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; - else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; - else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; - else StochTree::Log::Fatal("Invalid model type"); - - // Unpack leaf model parameters - double leaf_scale; - Eigen::MatrixXd leaf_scale_matrix; - if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || - (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { - leaf_scale = leaf_model_scale_input(0,0); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int num_row = leaf_model_scale_input.nrow(); - int num_col = leaf_model_scale_input.ncol(); - leaf_scale_matrix.resize(num_row, num_col); - for (int i = 0; i < num_row; i++) { - for (int 
j = 0; j < num_col; j++) { - leaf_scale_matrix(i,j) = leaf_model_scale_input(i,j); - } - } - } - - // Convert variable weights to std::vector - std::vector var_weights_vector(variable_weights.size()); - for (int i = 0; i < variable_weights.size(); i++) { - var_weights_vector[i] = variable_weights[i]; - } - - // Prepare the samplers - StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); - int num_basis = data->NumBasis(); - - // Run one iteration of the sampler - if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); - } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); - } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, 
pre_initialized, false, num_features_subsample, num_threads); - } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + bool keep_forest, int num_features_subsample, + int num_threads) { + // Refactoring completely out of the R interface. + // Intention to refactor out of the C++ interface in the future. + bool pre_initialized = true; + + // Unpack feature types + std::vector feature_types_(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_[i] = static_cast(feature_types[i]); + } + + // Unpack sweep indices + std::vector sweep_indices_(sweep_indices.size()); + // if (sweep_indices.size() > 0) { + // sweep_indices_.resize(sweep_indices.size()); + for (int i = 0; i < sweep_indices.size(); i++) { + sweep_indices_[i] = sweep_indices[i]; + } + // } + + // Convert leaf model type to enum + StochTree::ModelType model_type; + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) + model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) + model_type = StochTree::ModelType::kCloglogOrdinal; + else + StochTree::Log::Fatal("Invalid model type"); + + // Unpack leaf model parameters + double leaf_scale; + Eigen::MatrixXd leaf_scale_matrix; + if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || + (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { + leaf_scale = leaf_model_scale_input(0, 0); + } else if 
(model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + int num_row = leaf_model_scale_input.nrow(); + int num_col = leaf_model_scale_input.ncol(); + leaf_scale_matrix.resize(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + leaf_scale_matrix(i, j) = leaf_model_scale_input(i, j); + } } + } + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Prepare the samplers + StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); + int num_basis = data->NumBasis(); + + // Run one iteration of the sampler + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); + } else if (model_type == 
StochTree::ModelType::kLogLinearVariance) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } } [[cpp11::register]] @@ -110,234 +115,235 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); - for (int i = 0; i < feature_types.size(); i++) { - feature_types_[i] = static_cast(feature_types[i]); - } - - // Unpack sweep indices - std::vector sweep_indices_; - if (sweep_indices.size() > 0) { - sweep_indices_.resize(sweep_indices.size()); - for (int i = 0; i < sweep_indices.size(); i++) { - sweep_indices_[i] = sweep_indices[i]; - } - } - - // Convert leaf model type to enum - StochTree::ModelType model_type; - if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; - else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; - else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; - else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; - else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; - else StochTree::Log::Fatal("Invalid model type"); - - // Unpack leaf model parameters - double leaf_scale; - Eigen::MatrixXd leaf_scale_matrix; - if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || - (model_type == 
StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { - leaf_scale = leaf_model_scale_input(0,0); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - int num_row = leaf_model_scale_input.nrow(); - int num_col = leaf_model_scale_input.ncol(); - leaf_scale_matrix.resize(num_row, num_col); - for (int i = 0; i < num_row; i++) { - for (int j = 0; j < num_col; j++) { - leaf_scale_matrix(i,j) = leaf_model_scale_input(i,j); - } - } - } - - // Convert variable weights to std::vector - std::vector var_weights_vector(variable_weights.size()); - for (int i = 0; i < variable_weights.size(); i++) { - var_weights_vector[i] = variable_weights[i]; + bool keep_forest, int num_threads) { + // Refactoring completely out of the R interface. + // Intention to refactor out of the C++ interface in the future. + bool pre_initialized = true; + + // Unpack feature types + std::vector feature_types_(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_[i] = static_cast(feature_types[i]); + } + + // Unpack sweep indices + std::vector sweep_indices_; + if (sweep_indices.size() > 0) { + sweep_indices_.resize(sweep_indices.size()); + for (int i = 0; i < sweep_indices.size(); i++) { + sweep_indices_[i] = sweep_indices[i]; } - - // Prepare the samplers - StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); - int num_basis = data->NumBasis(); - - // Run one iteration of the sampler - if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); - } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, 
std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); - } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); - } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); - } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } + + // Convert leaf model type to enum + StochTree::ModelType model_type; + if (leaf_model_int == 0) + model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) + model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) + model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; + else if (leaf_model_int == 3) + model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) + model_type = StochTree::ModelType::kCloglogOrdinal; + else + StochTree::Log::Fatal("Invalid model type"); + + // Unpack leaf model parameters + double leaf_scale; + Eigen::MatrixXd leaf_scale_matrix; + if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || + (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { + leaf_scale = leaf_model_scale_input(0, 
0); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + int num_row = leaf_model_scale_input.nrow(); + int num_col = leaf_model_scale_input.ncol(); + leaf_scale_matrix.resize(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + leaf_scale_matrix(i, j) = leaf_model_scale_input(i, j); + } } + } + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Prepare the samplers + StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); + int num_basis = data->NumBasis(); + + // Run one iteration of the sampler + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); + } else if (model_type == StochTree::ModelType::kLogLinearVariance) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, 
var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } } [[cpp11::register]] double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, cpp11::external_pointer dataset, cpp11::external_pointer rng, - double a, double b -) { - // Run one iteration of the sampler - StochTree::GlobalHomoskedasticVarianceModel var_model = StochTree::GlobalHomoskedasticVarianceModel(); - if (dataset->HasVarWeights()) { - return var_model.SampleVarianceParameter(residual->GetData(), dataset->GetVarWeights(), a, b, *rng); - } else { - return var_model.SampleVarianceParameter(residual->GetData(), a, b, *rng); - } + double a, double b) { + // Run one iteration of the sampler + StochTree::GlobalHomoskedasticVarianceModel var_model = StochTree::GlobalHomoskedasticVarianceModel(); + if (dataset->HasVarWeights()) { + return var_model.SampleVarianceParameter(residual->GetData(), dataset->GetVarWeights(), a, b, *rng); + } else { + return var_model.SampleVarianceParameter(residual->GetData(), a, b, *rng); + } } [[cpp11::register]] double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer rng, - double a, double b -) { - // Run one iteration of the sampler - StochTree::LeafNodeHomoskedasticVarianceModel var_model = StochTree::LeafNodeHomoskedasticVarianceModel(); - return var_model.SampleVarianceParameter(active_forest.get(), a, b, *rng); + double a, double b) { + // Run one iteration of the sampler + StochTree::LeafNodeHomoskedasticVarianceModel var_model = StochTree::LeafNodeHomoskedasticVarianceModel(); + return var_model.SampleVarianceParameter(active_forest.get(), a, b, *rng); } 
[[cpp11::register]] cpp11::external_pointer rng_cpp(int random_seed = -1) { - std::unique_ptr rng_; - if (random_seed == -1) { - std::random_device rd; - rng_ = std::make_unique(rd()); - } else { - rng_ = std::make_unique(random_seed); - } + std::unique_ptr rng_; + if (random_seed == -1) { + std::random_device rd; + rng_ = std::make_unique(rd()); + } else { + rng_ = std::make_unique(random_seed); + } - // Release management of the pointer to R session - return cpp11::external_pointer(rng_.release()); + // Release management of the pointer to R session + return cpp11::external_pointer(rng_.release()); } [[cpp11::register]] cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth = -1) { - // Create smart pointer to newly allocated object - std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); + // Create smart pointer to newly allocated object + std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); - // Release management of the pointer to R session - return cpp11::external_pointer(prior_ptr_.release()); + // Release management of the pointer to R session + return cpp11::external_pointer(prior_ptr_.release()); } [[cpp11::register]] void update_alpha_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, double alpha) { - // Update alpha - tree_prior_ptr->SetAlpha(alpha); + // Update alpha + tree_prior_ptr->SetAlpha(alpha); } [[cpp11::register]] void update_beta_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, double beta) { - // Update beta - tree_prior_ptr->SetBeta(beta); + // Update beta + tree_prior_ptr->SetBeta(beta); } [[cpp11::register]] void update_min_samples_leaf_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, int min_samples_leaf) { - // Update min_samples_leaf - tree_prior_ptr->SetMinSamplesLeaf(min_samples_leaf); + // Update min_samples_leaf + tree_prior_ptr->SetMinSamplesLeaf(min_samples_leaf); } [[cpp11::register]] void 
update_max_depth_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, int max_depth) { - // Update max_depth - tree_prior_ptr->SetMaxDepth(max_depth); + // Update max_depth + tree_prior_ptr->SetMaxDepth(max_depth); } [[cpp11::register]] double get_alpha_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr) { - return tree_prior_ptr->GetAlpha(); + return tree_prior_ptr->GetAlpha(); } [[cpp11::register]] double get_beta_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr) { - // Update beta - return tree_prior_ptr->GetBeta(); + // Update beta + return tree_prior_ptr->GetBeta(); } [[cpp11::register]] int get_min_samples_leaf_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr) { - return tree_prior_ptr->GetMinSamplesLeaf(); + return tree_prior_ptr->GetMinSamplesLeaf(); } [[cpp11::register]] int get_max_depth_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr) { - return tree_prior_ptr->GetMaxDepth(); + return tree_prior_ptr->GetMaxDepth(); } [[cpp11::register]] cpp11::external_pointer forest_tracker_cpp(cpp11::external_pointer data, cpp11::integers feature_types, int num_trees, StochTree::data_size_t n) { - // Convert vector of integers to std::vector of enum FeatureType - std::vector feature_types_(feature_types.size()); - for (int i = 0; i < feature_types.size(); i++) { - feature_types_[i] = static_cast(feature_types[i]); - } + // Convert vector of integers to std::vector of enum FeatureType + std::vector feature_types_(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_[i] = static_cast(feature_types[i]); + } - // Create smart pointer to newly allocated object - std::unique_ptr tracker_ptr_ = std::make_unique(data->GetCovariates(), feature_types_, num_trees, n); + // Create smart pointer to newly allocated object + std::unique_ptr tracker_ptr_ = std::make_unique(data->GetCovariates(), feature_types_, num_trees, n); - // Release management of the pointer to R session - return 
cpp11::external_pointer(tracker_ptr_.release()); + // Release management of the pointer to R session + return cpp11::external_pointer(tracker_ptr_.release()); } [[cpp11::register]] cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer tracker_ptr) { - int n_train = tracker_ptr->GetNumObservations(); - cpp11::writable::doubles output(n_train); - for (int i = 0; i < n_train; i++) { - output[i] = tracker_ptr->GetSamplePrediction(i); - } - return output; + int n_train = tracker_ptr->GetNumObservations(); + cpp11::writable::doubles output(n_train); + for (int i = 0; i < n_train; i++) { + output[i] = tracker_ptr->GetSamplePrediction(i); + } + return output; } [[cpp11::register]] cpp11::writable::integers sample_without_replacement_integer_cpp( cpp11::integers population_vector, cpp11::doubles sampling_probs, - int sample_size -) { - // Unpack pointer to population vector - int population_size = population_vector.size(); - int* population_vector_ptr = INTEGER(PROTECT(population_vector)); + int sample_size) { + // Unpack pointer to population vector + int population_size = population_vector.size(); + int* population_vector_ptr = INTEGER(PROTECT(population_vector)); - // Unpack pointer to sampling probabilities - double* sampling_probs_ptr = REAL(PROTECT(sampling_probs)); + // Unpack pointer to sampling probabilities + double* sampling_probs_ptr = REAL(PROTECT(sampling_probs)); - // Create output vector - cpp11::writable::integers output(sample_size); + // Create output vector + cpp11::writable::integers output(sample_size); - // Unpack pointer to output vector - int* output_ptr = INTEGER(PROTECT(output)); + // Unpack pointer to output vector + int* output_ptr = INTEGER(PROTECT(output)); - // Create C++ RNG - std::random_device rd; - std::mt19937 gen(rd()); + // Create C++ RNG + std::random_device rd; + std::mt19937 gen(rd()); - // Run the sampler - StochTree::sample_without_replacement( - output_ptr, sampling_probs_ptr, population_vector_ptr, 
population_size, sample_size, gen - ); + // Run the sampler + StochTree::sample_without_replacement( + output_ptr, sampling_probs_ptr, population_vector_ptr, population_size, sample_size, gen); - // Unprotect raw pointers - UNPROTECT(3); + // Unprotect raw pointers + UNPROTECT(3); - // Return result - return(output); + // Return result + return (output); } [[cpp11::register]] cpp11::external_pointer ordinal_sampler_cpp() { - std::unique_ptr sampler_ptr = std::make_unique(); - return cpp11::external_pointer(sampler_ptr.release()); + std::unique_ptr sampler_ptr = std::make_unique(); + return cpp11::external_pointer(sampler_ptr.release()); } [[cpp11::register]] @@ -345,9 +351,8 @@ void ordinal_sampler_update_latent_variables_cpp( cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, - cpp11::external_pointer rng_ptr -) { - sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *rng_ptr); + cpp11::external_pointer rng_ptr) { + sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *rng_ptr); } [[cpp11::register]] @@ -358,15 +363,13 @@ void ordinal_sampler_update_gamma_params_cpp( double alpha_gamma, double beta_gamma, double gamma_0, - cpp11::external_pointer rng_ptr -) { - sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), alpha_gamma, beta_gamma, gamma_0, *rng_ptr); + cpp11::external_pointer rng_ptr) { + sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), alpha_gamma, beta_gamma, gamma_0, *rng_ptr); } [[cpp11::register]] void ordinal_sampler_update_cumsum_exp_cpp( cpp11::external_pointer sampler_ptr, - cpp11::external_pointer data_ptr -) { - sampler_ptr->UpdateCumulativeExpSums(*data_ptr); + cpp11::external_pointer data_ptr) { + sampler_ptr->UpdateCumulativeExpSums(*data_ptr); } diff --git a/src/serialization.cpp b/src/serialization.cpp index fb248f62..6c779e43 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -11,405 +11,405 @@ 
[[cpp11::register]] cpp11::external_pointer init_json_cpp() { - std::unique_ptr json_ptr = std::make_unique(); - json forests = nlohmann::json::object(); - json rfx = nlohmann::json::object(); - json parameters = nlohmann::json::object(); - json_ptr->emplace("forests", forests); - json_ptr->emplace("random_effects", rfx); - json_ptr->emplace("num_forests", 0); - json_ptr->emplace("num_random_effects", 0); - return cpp11::external_pointer(json_ptr.release()); + std::unique_ptr json_ptr = std::make_unique(); + json forests = nlohmann::json::object(); + json rfx = nlohmann::json::object(); + json parameters = nlohmann::json::object(); + json_ptr->emplace("forests", forests); + json_ptr->emplace("random_effects", rfx); + json_ptr->emplace("num_forests", 0); + json_ptr->emplace("num_random_effects", 0); + return cpp11::external_pointer(json_ptr.release()); } [[cpp11::register]] void json_add_double_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, double field_value) { - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - json_ptr->at(subfolder_name).at(field_name) = field_value; - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name) = field_value; } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_double_cpp(cpp11::external_pointer json_ptr, std::string field_name, double field_value) { - if 
(json_ptr->contains(field_name)) { - json_ptr->at(field_name) = field_value; - } else { - json_ptr->emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name) = field_value; + } else { + json_ptr->emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_integer_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, int field_value) { - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - json_ptr->at(subfolder_name).at(field_name) = field_value; - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name) = field_value; } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_integer_cpp(cpp11::external_pointer json_ptr, std::string field_name, int field_value) { - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name) = field_value; - } else { - json_ptr->emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name) = field_value; + } else { + json_ptr->emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_bool_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, bool field_value) { - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - 
json_ptr->at(subfolder_name).at(field_name) = field_value; - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name) = field_value; } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_bool_cpp(cpp11::external_pointer json_ptr, std::string field_name, bool field_value) { - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name) = field_value; - } else { - json_ptr->emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name) = field_value; + } else { + json_ptr->emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, cpp11::doubles field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - json_ptr->at(subfolder_name).at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } + int vec_length = field_vector.size(); + if (json_ptr->contains(subfolder_name)) { + if 
(json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } + } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); } + } } [[cpp11::register]] void json_add_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name, cpp11::doubles field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } + int vec_length = field_vector.size(); + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); } + } else { + json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); + } + 
} } [[cpp11::register]] void json_add_integer_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, cpp11::integers field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - json_ptr->at(subfolder_name).at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } + int vec_length = field_vector.size(); + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } + } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); } + } } [[cpp11::register]] void json_add_integer_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name, 
cpp11::integers field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } + int vec_length = field_vector.size(); + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); + } + } else { + json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); } + } } [[cpp11::register]] void json_add_string_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, cpp11::strings field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - json_ptr->at(subfolder_name).at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } - } + int vec_length = field_vector.size(); + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - 
json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); - } + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); + } + } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(subfolder_name).at(field_name).emplace_back(field_vector.at(i)); } + } } [[cpp11::register]] void json_add_string_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name, cpp11::strings field_vector) { - int vec_length = field_vector.size(); - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name).clear(); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } - } else { - json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); - for (int i = 0; i < vec_length; i++) { - json_ptr->at(field_name).emplace_back(field_vector.at(i)); - } + int vec_length = field_vector.size(); + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name).clear(); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); + } + } else { + json_ptr->emplace(std::pair(field_name, nlohmann::json::array())); + for (int i = 0; i < vec_length; i++) { + json_ptr->at(field_name).emplace_back(field_vector.at(i)); } + } } [[cpp11::register]] void json_add_string_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name, std::string field_value) { - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - 
json_ptr->at(subfolder_name).at(field_name) = field_value; - } else { - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + json_ptr->at(subfolder_name).at(field_name) = field_value; } else { - json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); - json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); } + } else { + json_ptr->emplace(std::pair(subfolder_name, nlohmann::json::object())); + json_ptr->at(subfolder_name).emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] void json_add_string_cpp(cpp11::external_pointer json_ptr, std::string field_name, std::string field_value) { - if (json_ptr->contains(field_name)) { - json_ptr->at(field_name) = field_value; - } else { - json_ptr->emplace(std::pair(field_name, field_value)); - } + if (json_ptr->contains(field_name)) { + json_ptr->at(field_name) = field_value; + } else { + json_ptr->emplace(std::pair(field_name, field_value)); + } } [[cpp11::register]] bool json_contains_field_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - if (json_ptr->contains(subfolder_name)) { - if (json_ptr->at(subfolder_name).contains(field_name)) { - return true; - } else { - return false; - } + if (json_ptr->contains(subfolder_name)) { + if (json_ptr->at(subfolder_name).contains(field_name)) { + return true; } else { - return false; + return false; } + } else { + return false; + } } [[cpp11::register]] bool json_contains_field_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - if (json_ptr->contains(field_name)) { - return true; - } else { - return false; - } + if (json_ptr->contains(field_name)) { + return true; + } else { + return false; + } } [[cpp11::register]] double 
json_extract_double_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - return json_ptr->at(subfolder_name).at(field_name); + return json_ptr->at(subfolder_name).at(field_name); } [[cpp11::register]] double json_extract_double_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - return json_ptr->at(field_name); + return json_ptr->at(field_name); } [[cpp11::register]] int json_extract_integer_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - return json_ptr->at(subfolder_name).at(field_name); + return json_ptr->at(subfolder_name).at(field_name); } [[cpp11::register]] int json_extract_integer_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - return json_ptr->at(field_name); + return json_ptr->at(field_name); } [[cpp11::register]] bool json_extract_bool_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - return json_ptr->at(subfolder_name).at(field_name); + return json_ptr->at(subfolder_name).at(field_name); } [[cpp11::register]] bool json_extract_bool_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - return json_ptr->at(field_name); + return json_ptr->at(field_name); } [[cpp11::register]] std::string json_extract_string_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - return json_ptr->at(subfolder_name).at(field_name); + return json_ptr->at(subfolder_name).at(field_name); } [[cpp11::register]] std::string json_extract_string_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - return json_ptr->at(field_name); + return json_ptr->at(field_name); } [[cpp11::register]] cpp11::writable::doubles json_extract_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - cpp11::writable::doubles output; - int vec_length = 
json_ptr->at(subfolder_name).at(field_name).size(); - for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(subfolder_name).at(field_name).at(i))); - return output; + cpp11::writable::doubles output; + int vec_length = json_ptr->at(subfolder_name).at(field_name).size(); + for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(subfolder_name).at(field_name).at(i))); + return output; } [[cpp11::register]] cpp11::writable::doubles json_extract_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - cpp11::writable::doubles output; - int vec_length = json_ptr->at(field_name).size(); - for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(field_name).at(i))); - return output; + cpp11::writable::doubles output; + int vec_length = json_ptr->at(field_name).size(); + for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(field_name).at(i))); + return output; } [[cpp11::register]] cpp11::writable::integers json_extract_integer_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - cpp11::writable::integers output; - int vec_length = json_ptr->at(subfolder_name).at(field_name).size(); - for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(subfolder_name).at(field_name).at(i))); - return output; + cpp11::writable::integers output; + int vec_length = json_ptr->at(subfolder_name).at(field_name).size(); + for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(subfolder_name).at(field_name).at(i))); + return output; } [[cpp11::register]] cpp11::writable::integers json_extract_integer_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - cpp11::writable::integers output; - int vec_length = json_ptr->at(field_name).size(); - for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(field_name).at(i))); - return output; + cpp11::writable::integers output; + int vec_length = 
json_ptr->at(field_name).size(); + for (int i = 0; i < vec_length; i++) output.push_back((json_ptr->at(field_name).at(i))); + return output; } [[cpp11::register]] cpp11::writable::strings json_extract_string_vector_subfolder_cpp(cpp11::external_pointer json_ptr, std::string subfolder_name, std::string field_name) { - int vec_length = json_ptr->at(subfolder_name).at(field_name).size(); - std::vector output(vec_length); - for (int i = 0; i < vec_length; i++) output.at(i) = json_ptr->at(subfolder_name).at(field_name).at(i); - return output; + int vec_length = json_ptr->at(subfolder_name).at(field_name).size(); + std::vector output(vec_length); + for (int i = 0; i < vec_length; i++) output.at(i) = json_ptr->at(subfolder_name).at(field_name).at(i); + return output; } [[cpp11::register]] cpp11::writable::strings json_extract_string_vector_cpp(cpp11::external_pointer json_ptr, std::string field_name) { - int vec_length = json_ptr->at(field_name).size(); - std::vector output(vec_length); - for (int i = 0; i < vec_length; i++) output.at(i) = json_ptr->at(field_name).at(i); - return output; + int vec_length = json_ptr->at(field_name).size(); + std::vector output(vec_length); + for (int i = 0; i < vec_length; i++) output.at(i) = json_ptr->at(field_name).at(i); + return output; } [[cpp11::register]] std::string json_add_forest_cpp(cpp11::external_pointer json_ptr, cpp11::external_pointer forest_samples) { - int forest_num = json_ptr->at("num_forests"); - std::string forest_label = "forest_" + std::to_string(forest_num); - nlohmann::json forest_json = forest_samples->to_json(); - json_ptr->at("forests").emplace(forest_label, forest_json); - json_ptr->at("num_forests") = forest_num + 1; - return forest_label; + int forest_num = json_ptr->at("num_forests"); + std::string forest_label = "forest_" + std::to_string(forest_num); + nlohmann::json forest_json = forest_samples->to_json(); + json_ptr->at("forests").emplace(forest_label, forest_json); + json_ptr->at("num_forests") = 
forest_num + 1; + return forest_label; } [[cpp11::register]] void json_increment_rfx_count_cpp(cpp11::external_pointer json_ptr) { - int rfx_num = json_ptr->at("num_random_effects"); - json_ptr->at("num_random_effects") = rfx_num + 1; + int rfx_num = json_ptr->at("num_random_effects"); + json_ptr->at("num_random_effects") = rfx_num + 1; } [[cpp11::register]] std::string json_add_rfx_container_cpp(cpp11::external_pointer json_ptr, cpp11::external_pointer rfx_samples) { - int rfx_num = json_ptr->at("num_random_effects"); - std::string rfx_label = "random_effect_container_" + std::to_string(rfx_num); - nlohmann::json rfx_json = rfx_samples->to_json(); - json_ptr->at("random_effects").emplace(rfx_label, rfx_json); - return rfx_label; + int rfx_num = json_ptr->at("num_random_effects"); + std::string rfx_label = "random_effect_container_" + std::to_string(rfx_num); + nlohmann::json rfx_json = rfx_samples->to_json(); + json_ptr->at("random_effects").emplace(rfx_label, rfx_json); + return rfx_label; } [[cpp11::register]] std::string json_add_rfx_label_mapper_cpp(cpp11::external_pointer json_ptr, cpp11::external_pointer label_mapper) { - int rfx_num = json_ptr->at("num_random_effects"); - std::string rfx_label = "random_effect_label_mapper_" + std::to_string(rfx_num); - nlohmann::json rfx_json = label_mapper->to_json(); - json_ptr->at("random_effects").emplace(rfx_label, rfx_json); - return rfx_label; + int rfx_num = json_ptr->at("num_random_effects"); + std::string rfx_label = "random_effect_label_mapper_" + std::to_string(rfx_num); + nlohmann::json rfx_json = label_mapper->to_json(); + json_ptr->at("random_effects").emplace(rfx_label, rfx_json); + return rfx_label; } [[cpp11::register]] std::string json_add_rfx_groupids_cpp(cpp11::external_pointer json_ptr, cpp11::integers groupids) { - int rfx_num = json_ptr->at("num_random_effects"); - std::string rfx_label = "random_effect_groupids_" + std::to_string(rfx_num); - nlohmann::json groupids_json = nlohmann::json::array(); - 
for (int i = 0; i < groupids.size(); i++) { - groupids_json.emplace_back(groupids.at(i)); - } - json_ptr->at("random_effects").emplace(rfx_label, groupids_json); - return rfx_label; + int rfx_num = json_ptr->at("num_random_effects"); + std::string rfx_label = "random_effect_groupids_" + std::to_string(rfx_num); + nlohmann::json groupids_json = nlohmann::json::array(); + for (int i = 0; i < groupids.size(); i++) { + groupids_json.emplace_back(groupids.at(i)); + } + json_ptr->at("random_effects").emplace(rfx_label, groupids_json); + return rfx_label; } [[cpp11::register]] std::string get_json_string_cpp(cpp11::external_pointer json_ptr) { - return json_ptr->dump(); + return json_ptr->dump(); } [[cpp11::register]] void json_save_file_cpp(cpp11::external_pointer json_ptr, std::string filename) { - std::ofstream output_file(filename); - output_file << *json_ptr << std::endl; + std::ofstream output_file(filename); + output_file << *json_ptr << std::endl; } [[cpp11::register]] void json_load_file_cpp(cpp11::external_pointer json_ptr, std::string filename) { - std::ifstream f(filename); - // nlohmann::json file_json = nlohmann::json::parse(f); - *json_ptr = nlohmann::json::parse(f); - // json_ptr.reset(&file_json); + std::ifstream f(filename); + // nlohmann::json file_json = nlohmann::json::parse(f); + *json_ptr = nlohmann::json::parse(f); + // json_ptr.reset(&file_json); } [[cpp11::register]] void json_load_string_cpp(cpp11::external_pointer json_ptr, std::string json_string) { - *json_ptr = nlohmann::json::parse(json_string); + *json_ptr = nlohmann::json::parse(json_string); } diff --git a/src/tree.cpp b/src/tree.cpp index 32c51475..43debc57 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -1,5 +1,5 @@ /*! 
- * Inspired by the design of the tree in the xgboost and treelite package, both released under the Apache license + * Inspired by the design of the tree in the xgboost and treelite package, both released under the Apache license * with the following copyright: * Copyright 2015-2023 by XGBoost Contributors * Copyright 2017-2021 by [treelite] Contributors @@ -24,14 +24,14 @@ std::int32_t Tree::NumLeafParents() const { } std::int32_t Tree::NumSplitNodes() const { - std::int32_t splits { 0 }; + std::int32_t splits{0}; auto const& self = *this; this->WalkTree([&splits, &self](std::int32_t nidx) { - if (!self.IsLeaf(nidx)){ - splits++; - } - return true; - }); + if (!self.IsLeaf(nidx)) { + splits++; + } + return true; + }); return splits; } @@ -110,10 +110,10 @@ std::int32_t Tree::AllocNode() { --num_deleted_nodes; return nid; } - + std::int32_t nd = num_nodes++; CHECK_LT(num_nodes, std::numeric_limits::max()); - + node_type_.push_back(TreeNodeType::kLeafNode); cleft_.push_back(kInvalidNodeId); cright_.push_back(kInvalidNodeId); @@ -121,7 +121,7 @@ std::int32_t Tree::AllocNode() { leaf_value_.push_back(static_cast(0)); threshold_.push_back(static_cast(0)); node_deleted_.push_back(false); - // THIS is a placeholder, currently set after AllocNode is called ... + // THIS is a placeholder, currently set after AllocNode is called ... // ... to be refactored ... 
parent_.push_back(static_cast(0)); @@ -169,7 +169,7 @@ void Tree::ExpandNode(std::int32_t nid, int split_index, double split_value, dou internal_nodes_.push_back(nid); // Remove nid's parent node (if applicable) from leaf parents - if (!IsRoot(nid)){ + if (!IsRoot(nid)) { std::int32_t parent_idx = Parent(nid); leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), parent_idx), leaf_parents_.end()); } @@ -195,7 +195,7 @@ void Tree::ExpandNode(std::int32_t nid, int split_index, std::vector& void Tree::PredictLeafIndexInplace(Eigen::MatrixXd& covariates, std::vector& output, int32_t offset, int32_t max_leaf) { int n = covariates.rows(); CHECK_GE(output.size(), offset + n); - std::map renumber_map; + std::map renumber_map; for (int i = 0; i < leaves_.size(); i++) { renumber_map.insert({leaves_[i], i}); } @@ -437,7 +437,7 @@ void Tree::PredictLeafIndexInplace(Eigen::MatrixXd& covariates, std::vector>& covariates, std::vector& output, int32_t offset, int32_t max_leaf) { int n = covariates.rows(); CHECK_GE(output.size(), offset + n); - std::map renumber_map; + std::map renumber_map; for (int i = 0; i < leaves_.size(); i++) { renumber_map.insert({leaves_[i], i}); } @@ -448,12 +448,12 @@ void Tree::PredictLeafIndexInplace(Eigen::Map>& covariates, - Eigen::Map>& output, +void Tree::PredictLeafIndexInplace(Eigen::Map>& covariates, + Eigen::Map>& output, int column_ind, int32_t offset, int32_t max_leaf) { int n = covariates.rows(); CHECK_GE(output.size(), offset + n); - std::map renumber_map; + std::map renumber_map; for (int i = 0; i < leaves_.size(); i++) { renumber_map.insert({leaves_[i], i}); } @@ -481,26 +481,26 @@ void TreeNodeVectorsToJson(json& obj, Tree* tree) { tree_array_map.emplace(std::pair("category_list_end", json::array())); // Extract only the non-deleted nodes into tree_array_map -// bool node_deleted; + // bool node_deleted; for (int i = 0; i < tree->NumNodes(); i++) { -// node_deleted = (std::find(tree->deleted_nodes_.begin(), 
tree->deleted_nodes_.end(), i) -// != tree->deleted_nodes_.end()); -// if (!node_deleted) { - tree_array_map["node_type"].emplace_back(static_cast(tree->node_type_[i])); - tree_array_map["parent"].emplace_back(tree->parent_[i]); - tree_array_map["left"].emplace_back(tree->cleft_[i]); - tree_array_map["right"].emplace_back(tree->cright_[i]); - tree_array_map["split_index"].emplace_back(tree->split_index_[i]); - tree_array_map["leaf_value"].emplace_back(tree->leaf_value_[i]); - tree_array_map["threshold"].emplace_back(tree->threshold_[i]); - tree_array_map["node_deleted"].emplace_back(tree->node_deleted_[i]); - tree_array_map["leaf_vector_begin"].emplace_back(static_cast(tree->leaf_vector_begin_[i])); - tree_array_map["leaf_vector_end"].emplace_back(static_cast(tree->leaf_vector_end_[i])); - tree_array_map["category_list_begin"].emplace_back(static_cast(tree->category_list_begin_[i])); - tree_array_map["category_list_end"].emplace_back(static_cast(tree->category_list_end_[i])); -// } - } - + // node_deleted = (std::find(tree->deleted_nodes_.begin(), tree->deleted_nodes_.end(), i) + // != tree->deleted_nodes_.end()); + // if (!node_deleted) { + tree_array_map["node_type"].emplace_back(static_cast(tree->node_type_[i])); + tree_array_map["parent"].emplace_back(tree->parent_[i]); + tree_array_map["left"].emplace_back(tree->cleft_[i]); + tree_array_map["right"].emplace_back(tree->cright_[i]); + tree_array_map["split_index"].emplace_back(tree->split_index_[i]); + tree_array_map["leaf_value"].emplace_back(tree->leaf_value_[i]); + tree_array_map["threshold"].emplace_back(tree->threshold_[i]); + tree_array_map["node_deleted"].emplace_back(tree->node_deleted_[i]); + tree_array_map["leaf_vector_begin"].emplace_back(static_cast(tree->leaf_vector_begin_[i])); + tree_array_map["leaf_vector_end"].emplace_back(static_cast(tree->leaf_vector_end_[i])); + tree_array_map["category_list_begin"].emplace_back(static_cast(tree->category_list_begin_[i])); + 
tree_array_map["category_list_end"].emplace_back(static_cast(tree->category_list_end_[i])); + // } + } + // Unpack the map into the reference JSON object for (auto& pair : tree_array_map) { obj.emplace(pair); @@ -532,7 +532,7 @@ void NodeListsToJson(json& obj, Tree* tree) { json vec_leaf_parents = json::array(); json vec_leaves = json::array(); json vec_deleted_nodes = json::array(); - + if (tree->internal_nodes_.size() > 0) { for (int i = 0; i < tree->internal_nodes_.size(); i++) { vec_internal_nodes.emplace_back(tree->internal_nodes_[i]); @@ -556,7 +556,7 @@ void NodeListsToJson(json& obj, Tree* tree) { vec_deleted_nodes.emplace_back(tree->deleted_nodes_[i]); } } - + obj.emplace("internal_nodes", vec_internal_nodes); obj.emplace("leaf_parents", vec_leaf_parents); obj.emplace("leaves", vec_leaves); @@ -577,7 +577,7 @@ json Tree::to_json() { MultivariateLeafVectorToJson(result_obj, this); SplitCategoryVectorToJson(result_obj, this); NodeListsToJson(result_obj, this); - + // Initialize Json from Json::object map and return result return result_obj; } @@ -603,8 +603,10 @@ void JsonToTreeNodeVectors(const json& tree_json, Tree* tree) { tree->cleft_.push_back(tree_json.at("left").at(i)); tree->cright_.push_back(tree_json.at("right").at(i)); tree->split_index_.push_back(tree_json.at("split_index").at(i)); - if (is_univariate) tree->leaf_value_.push_back(tree_json.at("leaf_value").at(i)); - else tree->leaf_value_.push_back(0.); + if (is_univariate) + tree->leaf_value_.push_back(tree_json.at("leaf_value").at(i)); + else + tree->leaf_value_.push_back(0.); tree->threshold_.push_back(tree_json.at("threshold").at(i)); tree->node_deleted_.push_back(tree_json.at("node_deleted").at(i)); // Handle type conversions for node_type, leaf_vector_begin/end, and category_list_begin/end @@ -665,7 +667,7 @@ void Tree::from_json(const json& tree_json) { tree_json.at("has_categorical_split").get_to(this->has_categorical_split_); 
tree_json.at("output_dimension").get_to(this->output_dimension_); tree_json.at("is_log_scale").get_to(this->is_log_scale_); - + // Unpack the array based fields JsonToTreeNodeVectors(tree_json, this); JsonToMultivariateLeafVector(tree_json, this); @@ -673,4 +675,4 @@ void Tree::from_json(const json& tree_json) { JsonToNodeLists(tree_json, this); } -} // namespace StochTree +} // namespace StochTree From b8e299d622fb423653b18bf59aee1f58b23b6677 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:22:16 -0400 Subject: [PATCH 06/64] Some clang-tidy low-hanging fruit fixes --- include/stochtree/data.h | 2 +- include/stochtree/ordinal_sampler.h | 1 - include/stochtree/partition_tracker.h | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/stochtree/data.h b/include/stochtree/data.h index e81dc17b..189595bd 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -119,7 +119,7 @@ static inline void FeatureUnpack(std::vector& categorical_variables, co static inline std::vector Str2FeatureVec(const char* parameters) { std::vector feature_vec; auto args = Common::Split(parameters, ","); - for (auto arg : args) { + for (const auto& arg : args) { FeatureUnpack(feature_vec, Common::Trim(arg).c_str()); } return feature_vec; diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index b4269987..4d5777b8 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -13,7 +13,6 @@ #include #include -#include namespace StochTree { diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index a0247884..6f3a94ae 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -420,7 +420,7 @@ class UnsortedNodeSampleTracker { /*! 
\brief Update SampleNodeMapper for all the observations in tree */ void UpdateObservationMapping(Tree* tree, int tree_id, SampleNodeMapper* sample_node_mapper) { - std::vector leaves = tree->GetLeaves(); + std::vector const& leaves = tree->GetLeaves(); int leaf; for (int i = 0; i < leaves.size(); i++) { leaf = leaves[i]; From 638f658a0214dbd2acfc9544ff41f8915fe3bf1e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 12:39:10 -0400 Subject: [PATCH 07/64] Added `.clangd` and updated cmakepresets and `.gitignore` --- .clangd | 14 ++++++++++++++ .gitignore | 2 +- CMakePresets.json | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 .clangd diff --git a/.clangd b/.clangd new file mode 100644 index 00000000..c87746fa --- /dev/null +++ b/.clangd @@ -0,0 +1,14 @@ +--- +# Fallback compile flags for files not captured by compile_commands.json +# (e.g. R wrapper files when BUILD_PYTHON=OFF, or before cmake has been run). +# Files compiled by cmake use compile_commands.json and ignore these flags. 
+CompileFlags: + Add: + - "-std=c++17" + - "-Iinclude" + - "-Ideps/boost_math/include" + - "-Ideps/eigen" + - "-Ideps/fast_double_parser/include" + - "-Ideps/fmt/include" + - "-Ideps/pybind11/include" + Compiler: clang++ diff --git a/.gitignore b/.gitignore index 6acc4362..1fb03657 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,7 @@ cpp_docs/doxyoutput/xml cpp_docs/doxyoutput/latex stochtree_cran *.trace -*.clangd +.cache/clangd/ *.claude ## R gitignore diff --git a/CMakePresets.json b/CMakePresets.json index 42c7644c..a1134ab8 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -24,7 +24,7 @@ "USE_DEBUG": "ON", "BUILD_DEBUG_TARGETS": "ON", "BUILD_TEST": "OFF", - "BUILD_PYTHON": "OFF", + "BUILD_PYTHON": "ON", "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" } }, From 00ddf31d81ed101426143ce1c3a2cd10826b75ef Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 13:08:52 -0400 Subject: [PATCH 08/64] Updated build infrastructure --- .vscode/tasks.json | 16 ++++++++-------- CMakeLists.txt | 31 +++++++++++-------------------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 585737d9..2ffeecb8 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -20,7 +20,7 @@ { "label": "CMake: Build (dev)", "type": "shell", - "command": "cmake --build --preset dev", + "command": "cmake --preset dev && cmake --build --preset dev", "group": { "kind": "build", "isDefault": true }, "problemMatcher": ["$gcc"], "presentation": { "reveal": "always", "panel": "shared" } @@ -28,7 +28,7 @@ { "label": "CMake: Build (dev-quick)", "type": "shell", - "command": "cmake --build --preset dev-quick", + "command": "cmake --preset dev-quick && cmake --build --preset dev-quick", "group": "build", "problemMatcher": ["$gcc"], "presentation": { "reveal": "always", "panel": "shared" } @@ -36,7 +36,7 @@ { "label": "CMake: Build (release)", "type": "shell", - "command": "cmake --build --preset release", + "command": "cmake --preset 
release && cmake --build --preset release", "group": "build", "problemMatcher": ["$gcc"], "presentation": { "reveal": "always", "panel": "shared" } @@ -44,23 +44,23 @@ { "label": "CMake: Build (sanitizer)", "type": "shell", - "command": "cmake --build --preset sanitizer", + "command": "cmake --preset sanitizer && cmake --build --preset sanitizer", "group": "build", "problemMatcher": ["$gcc"], "presentation": { "reveal": "always", "panel": "shared" } }, { - "label": "CTest: Run All", + "label": "Test: Run All", "type": "shell", - "command": "ctest --preset dev", + "command": "${workspaceFolder}/build/teststochtree", "group": { "kind": "test", "isDefault": true }, "problemMatcher": [], "presentation": { "reveal": "always", "panel": "shared" } }, { - "label": "CTest: Run All (sanitizer)", + "label": "Test: Run All (sanitizer)", "type": "shell", - "command": "ctest --preset sanitizer", + "command": "${workspaceFolder}/build/teststochtree", "group": "test", "problemMatcher": [], "presentation": { "reveal": "always", "panel": "shared" } diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f8a70ff..753c1975 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -178,27 +178,18 @@ endif() if(BUILD_TEST) include(FetchContent) - set(GTEST_SUBMODULE_DIR "${PROJECT_SOURCE_DIR}/deps/googletest") - if(EXISTS "${GTEST_SUBMODULE_DIR}/CMakeLists.txt") - # Use the local submodule — no network required. - # Initialize with: git submodule update --init deps/googletest - message(STATUS "GoogleTest: using local submodule at ${GTEST_SUBMODULE_DIR}") - FetchContent_Declare( - googletest - SOURCE_DIR "${GTEST_SUBMODULE_DIR}" - ) - else() - # Fall back to GitHub fetch (CI, shallow clones, or submodule not initialized). 
- if (NOT DEFINED GOOGLETEST_GIT_REPO) - set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git) - endif() - message(STATUS "GoogleTest: fetching from ${GOOGLETEST_GIT_REPO}") - FetchContent_Declare( - googletest - GIT_REPOSITORY ${GOOGLETEST_GIT_REPO} - GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0 - ) + # GoogleTest is fetched on first configure and cached in build/_deps/. + # Re-download is skipped automatically when FETCHCONTENT_UPDATES_DISCONNECTED=ON + # (set in the dev and sanitizer presets), so no submodule or internet access + # is needed after the initial cmake --preset dev. + if (NOT DEFINED GOOGLETEST_GIT_REPO) + set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git) endif() + FetchContent_Declare( + googletest + GIT_REPOSITORY ${GOOGLETEST_GIT_REPO} + GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0 + ) # For Windows: Prevent overriding the parent project's compiler/linker settings set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) From 3714c2e034afb0e9193ee3bf4b6546b24a2d4e18 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 13:32:42 -0400 Subject: [PATCH 09/64] Added probit function --- include/stochtree/distributions.h | 34 +++++++++++++++++++++++++++++++ include/stochtree/probit.h | 31 ++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 include/stochtree/probit.h diff --git a/include/stochtree/distributions.h b/include/stochtree/distributions.h index f31b7b16..3cf4120a 100644 --- a/include/stochtree/distributions.h +++ b/include/stochtree/distributions.h @@ -2,6 +2,7 @@ #define STOCHTREE_DISTRIBUTIONS_H #include #include +#include /*! * \brief A collection of random number generation utilities. * @@ -12,6 +13,23 @@ namespace StochTree { +/*! + * Standard normal cumulative distribution function, implemented via the complementary error function. 
+ */ +inline double norm_cdf(double x) { + return 0.5 * boost::math::erfc(-x / std::sqrt(2.0)); +} + +/*! + * Standard normal quantile function (inverse CDF), implemented via the inverse complementary error function. + */ +inline double norm_inv_cdf(double p) { + return -std::sqrt(2.0) * boost::math::erfc_inv(2.0 * p); +} + +/*! Precomputed standard normal CDF at 0 */ +static constexpr double Phi_0 = 0.5; + /*! * Generate a standard uniform random variate to 53 bits of precision via two mersenne twisters, see: * https://github.com/numpy/numpy/blob/0d7986494b39ace565afda3de68be528ddade602/numpy/random/src/mt19937/mt19937.h#L56 @@ -325,6 +343,22 @@ inline int sample_discrete_stateless(std::mt19937& gen, std::vector& wei return weights.size() - 1; } +/*! + * Generate a single sample from a truncated standard normal distribution, bounded above by 0. + */ +inline double sample_std_truncnorm_upper(std::mt19937& gen) { + double uniform_draw = standard_uniform_draw_53bit(gen); + return norm_inv_cdf(uniform_draw * Phi_0); +} + +/*! + * Generate a single sample from a truncated standard normal distribution, bounded below by 0. + */ +inline double sample_std_truncnorm_lower(std::mt19937& gen) { + double uniform_draw = standard_uniform_draw_53bit(gen); + return norm_inv_cdf(uniform_draw + (1 - uniform_draw) * Phi_0); +} + } // namespace StochTree #endif // STOCHTREE_DISTRIBUTIONS_H \ No newline at end of file diff --git a/include/stochtree/probit.h b/include/stochtree/probit.h new file mode 100644 index 00000000..79dd0558 --- /dev/null +++ b/include/stochtree/probit.h @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2024 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifndef STOCHTREE_PROBIT_H_ +#define STOCHTREE_PROBIT_H_ + +#include + +namespace StochTree { + +void sample_probit_latent_outcome(std::mt19937& gen, double* outcome, double* conditional_mean, double* latent_outcome, int n) { + double uniform_draw_std; + double uniform_draw_trunc; + double quantile; + for (int i = 0; i < n; i++) { + uniform_draw_std = standard_uniform_draw_53bit(gen); + quantile = norm_cdf(0 - conditional_mean[i]); + if (outcome[i] == 1.0) { + uniform_draw_trunc = quantile + uniform_draw_std * (1.0 - quantile); + latent_outcome[i] = norm_inv_cdf(uniform_draw_trunc) + conditional_mean[i]; + } else { + uniform_draw_trunc = uniform_draw_std * quantile; + latent_outcome[i] = norm_inv_cdf(uniform_draw_trunc) + conditional_mean[i]; + } + } +} + +} // namespace StochTree + +#endif // STOCHTREE_PROBIT_H_ \ No newline at end of file From de4c5c0d1ab89de8a5ece524fcf02c8ab84b9457 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 14:34:10 -0400 Subject: [PATCH 10/64] Added probit to debug program --- debug/bart_debug.cpp | 228 ++++++++++++++++++++------ include/stochtree/partition_tracker.h | 1 + 2 files changed, 179 insertions(+), 50 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index eeba33ae..4f17f486 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -11,14 +11,17 @@ #include #include +#include #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -28,63 +31,95 @@ static constexpr double kPi = 3.14159265358979323846; // ---- Data ------------------------------------------------------------ -struct Dataset { +struct RegressionDataset { Eigen::Matrix X; Eigen::VectorXd y; }; -// DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) -Dataset generate_data(int n, int p, std::mt19937& rng) { +struct ProbitDataset { + Eigen::Matrix X; + Eigen::VectorXd y; + Eigen::VectorXd Z; +}; + +// DGP: y ~ sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) +RegressionDataset 
generate_constant_leaf_regression_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); - Dataset d; + RegressionDataset d; d.X.resize(n, p); d.y.resize(n); for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) d.X(i, j) = unif(rng); for (int i = 0; i < n; i++) - d.y(i) = std::sin(2.0 * kPi * d.X(i, 0)) - + 0.5 * d.X(i, 1) - - 1.5 * d.X(i, 2) - + normal(rng); + d.y(i) = std::sin(2.0 * kPi * d.X(i, 0)) + 0.5 * d.X(i, 1) - 1.5 * d.X(i, 2) + normal(rng); return d; } -// ---- Scenario 0: homoskedastic constant-leaf BART ------------------- +// DGP +// --- +// Z ~ sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) +// y = 1{Z > 0} +ProbitDataset generate_probit_data(int n, int p, std::mt19937& rng) { + std::uniform_real_distribution unif(0.0, 1.0); + std::normal_distribution normal(0.0, 1.0); + Eigen::VectorXd Z; + ProbitDataset d; + d.X.resize(n, p); + d.y.resize(n); + d.Z.resize(n); + for (int i = 0; i < n; i++) + for (int j = 0; j < p; j++) + d.X(i, j) = unif(rng); + for (int i = 0; i < n; i++) { + d.Z(i) = std::sin(2.0 * kPi * d.X(i, 0)) + 0.5 * d.X(i, 1) - 1.5 * d.X(i, 2) + normal(rng); + d.y(i) = (d.Z(i) > 0) ? 1.0 : 0.0; + } + return d; +} -void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { - constexpr int num_threads = 1; - constexpr int cutpoint_grid_size = 100; - std::mt19937 rng(42); +// ---- Shared sampler loop -------------------------------------------- +// +// Runs GFR warmup then MCMC sampling, both using the same forest/leaf/variance +// setup. The two scenario-specific hooks are: +// +// post_iter(tracker, global_variance) — called after every forest sample in +// both GFR and MCMC (e.g. sample global variance, or augment latent Z). +// +// report_results(preds, global_variance) — called once after all samples are +// collected; receives the flat column-major predictions matrix and the +// final global variance value. 
- Dataset data = generate_data(n, p, rng); - double y_bar = data.y.mean(); - Eigen::VectorXd resid_vec = data.y.array() - y_bar; +using PostIterFn = std::function; +using ReportFn = std::function&, double)>; - StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); - StochTree::ColumnVector residual(resid_vec.data(), n); +void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, + StochTree::ForestDataset& dataset, + StochTree::ColumnVector& residual, std::mt19937& rng, + PostIterFn post_iter, ReportFn report_results) { + constexpr int num_threads = 1; + constexpr int cutpoint_grid_size = 100; std::vector feature_types(p, StochTree::FeatureType::kNumeric); std::vector var_weights(p, 1.0 / p); - std::vector sweep_indices; + std::vector sweep_indices(num_trees); + std::iota(sweep_indices.begin(), sweep_indices.end(), 0); StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); StochTree::ForestContainer forest_samples(num_trees, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); StochTree::TreeEnsemble active_forest(num_trees, 1, true, false); StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, num_trees, n); - double leaf_scale = 1.0 / num_trees; - StochTree::GaussianConstantLeafModel leaf_model(leaf_scale); + active_forest.SetLeafValue(0.0); + UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, false, std::minus()); + tracker.UpdatePredictions(&active_forest, dataset); + StochTree::GaussianConstantLeafModel leaf_model(1.0 / num_trees); double global_variance = 1.0; - constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior - StochTree::GlobalHomoskedasticVarianceModel var_model; - // GFR warmup — no samples stored std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; - bool pre_initialized = false; + bool pre_initialized = true; for (int i = 0; i < num_gfr; i++) { StochTree::GFRSampleOneIter< 
StochTree::GaussianConstantLeafModel, @@ -93,13 +128,10 @@ void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { dataset, residual, tree_prior, rng, var_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, - /*backfitting=*/true, /*num_features_subsample=*/-1, num_threads); - global_variance = var_model.SampleVarianceParameter( - residual.GetData(), a_sigma, b_sigma, rng); - pre_initialized = true; + /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); + post_iter(tracker, global_variance); } - // MCMC — store samples std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; for (int i = 0; i < num_mcmc; i++) { StochTree::MCMCSampleOneIter< @@ -110,42 +142,138 @@ void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { var_weights, sweep_indices, global_variance, /*keep_forest=*/true, /*pre_initialized=*/true, /*backfitting=*/true, num_threads); - global_variance = var_model.SampleVarianceParameter( - residual.GetData(), a_sigma, b_sigma, rng); + post_iter(tracker, global_variance); } // Posterior predictions: column-major, element [j*n + i] = sample j, obs i - std::vector preds = forest_samples.Predict(dataset); - double rmse_sum = 0.0; - for (int i = 0; i < n; i++) { - double mu_hat = y_bar; - for (int j = 0; j < num_mcmc; j++) - mu_hat += preds[static_cast(j * n + i)] / num_mcmc; - double err = mu_hat - data.y(i); - rmse_sum += err * err; - } + report_results(forest_samples.Predict(dataset), global_variance); +} + +// ---- Scenario 0: homoskedastic constant-leaf BART ------------------- + +void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { + std::mt19937 rng(seed); + + RegressionDataset data = generate_constant_leaf_regression_data(n, p, rng); + double y_bar = data.y.mean(); + double y_std = std::sqrt((data.y.array() - y_bar).square().sum() / (data.y.size() - 1)); + Eigen::VectorXd resid_vec 
= (data.y.array() - y_bar) / y_std; // standardize + + StochTree::ForestDataset dataset; + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + StochTree::ColumnVector residual(resid_vec.data(), n); + + constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior + StochTree::GlobalHomoskedasticVarianceModel var_model; + + auto post_iter = [&](StochTree::ForestTracker&, double& global_variance) { + global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); + }; - std::cout << "\nScenario 0 (HomoskedasticBART):\n" - << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" - << " sigma (last sample): " << std::sqrt(global_variance) << "\n" - << " sigma (truth): 1.0\n"; + auto report = [&](const std::vector& preds, double global_variance) { + double rmse_sum = 0.0; + for (int i = 0; i < n; i++) { + double mu_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + mu_hat += preds[static_cast(j * n + i)] / num_mcmc; + double err = (mu_hat * y_std + y_bar) - data.y(i); + rmse_sum += err * err; + } + std::cout << "\nScenario 0 (Homoskedastic BART):\n" + << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" + << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" + << " sigma (truth): 1.0\n"; + }; + + run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); +} + +// ---- Scenario 1: constant-leaf probit BART ------------------- + +void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { + std::mt19937 rng(seed); + + ProbitDataset data = generate_probit_data(n, p, rng); + double y_bar = StochTree::norm_cdf(data.y.mean()); + Eigen::VectorXd y_vec = data.y.array(); + Eigen::VectorXd Z_vec = (data.y.array() - y_bar); + + StochTree::ForestDataset dataset; + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + StochTree::ColumnVector residual(Z_vec.data(), n); + + // Data augmentation: sample latent Z given current forest 
predictions + auto post_iter = [&](StochTree::ForestTracker& tracker, double&) { + StochTree::sample_probit_latent_outcome( + rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), n); + }; + + auto report = [&](const std::vector& preds, double global_variance) { + double rmse_sum = 0.0; + for (int i = 0; i < n; i++) { + double mu_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + mu_hat += preds[static_cast(j * n + i)] / num_mcmc; + double err = (mu_hat + y_bar) - data.Z(i); + rmse_sum += err * err; + } + std::cout << "\nScenario 1 (Probit BART):\n" + << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" + << " sigma (truth): 1.0\n"; + }; + + run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); } // ---- Main ----------------------------------------------------------- int main(int argc, char** argv) { - int scenario = 0; - if (argc > 1) scenario = std::stoi(argv[1]); + int scenario = 1; + int n = 500; + int p = 5; + int num_trees = 200; + int num_gfr = 20; + int num_mcmc = 100; + int seed = 1234; - constexpr int n = 200, p = 5, num_trees = 200, num_gfr = 20, num_mcmc = 100; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if ((arg == "--scenario" || arg == "--n" || arg == "--p" || + arg == "--num_trees" || arg == "--num_gfr" || arg == "--num_mcmc" || arg == "--seed") && + i + 1 < argc) { + int val = std::stoi(argv[++i]); + if (arg == "--scenario") + scenario = val; + else if (arg == "--n") + n = val; + else if (arg == "--p") + p = val; + else if (arg == "--num_trees") + num_trees = val; + else if (arg == "--num_gfr") + num_gfr = val; + else if (arg == "--num_mcmc") + num_mcmc = val; + else if (arg == "--num_mcmc") + seed = val; + } else { + std::cerr << "Unknown or incomplete argument: " << arg << "\n" + << "Usage: bart_debug [--scenario N] [--n N] [--p N]" + " [--num_trees N] [--num_gfr N] [--num_mcmc N]\n"; + return 1; + } + } switch (scenario) { case 0: run_scenario_0(n, p, num_trees, 
num_gfr, num_mcmc); break; + case 1: + run_scenario_1(n, p, num_trees, num_gfr, num_mcmc); + break; default: std::cerr << "Unknown scenario " << scenario - << ". Available scenarios: 0 (HomoskedasticBART)\n"; + << ". Available scenarios: 0 (Homoskedastic BART), 1 (Probit BART)\n"; return 1; } return 0; diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 6f3a94ae..164ffae0 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -87,6 +87,7 @@ class ForestTracker { SampleNodeMapper* GetSampleNodeMapper() { return sample_node_mapper_.get(); } UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() { return unsorted_node_sample_tracker_.get(); } SortedNodeSampleTracker* GetSortedNodeSampleTracker() { return sorted_node_sample_tracker_.get(); } + double* GetSumPredictions() { return sum_predictions_.data(); } int GetNumObservations() { return num_observations_; } int GetNumTrees() { return num_trees_; } int GetNumFeatures() { return num_features_; } From 4abac7665238ea9f21de1ef0cb57442fff66cb7b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 14:34:18 -0400 Subject: [PATCH 11/64] Updated build infrastructure --- .gitignore | 1 + .vscode/launch.json | 23 +++++++++++++++++++++-- .vscode/tasks.json | 8 ++++++++ CMakePresets.json | 19 +++++++++++++++++++ 4 files changed, 49 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 1fb03657..c43ab93a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.DS_Store lib/ build/ +build-release*/ .vscode/positron/ xcode/ *.json diff --git a/.vscode/launch.json b/.vscode/launch.json index 4b2abca8..b8dcf442 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["0"], + "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", 
"1234"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -33,7 +33,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["0"], + "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "1234"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -58,6 +58,25 @@ "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev)" }, + { + "name": "bart_debug (macOS, Release)", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build-release-drivers/bart_debug", + "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "101"], + "cwd": "${workspaceFolder}", + "preLaunchTask": "CMake: Build (release-drivers)" + }, + { + "name": "bart_debug (Linux/Container, Release)", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/build-release-drivers/bart_debug", + "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "101"], + "cwd": "${workspaceFolder}", + "MIMode": "gdb", + "preLaunchTask": "CMake: Build (release-drivers)" + }, { "name": "Python Debugger: Current File", "type": "debugpy", diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 2ffeecb8..e21400a0 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -41,6 +41,14 @@ "problemMatcher": ["$gcc"], "presentation": { "reveal": "always", "panel": "shared" } }, + { + "label": "CMake: Build (release-drivers)", + "type": "shell", + "command": "cmake --preset release-drivers && cmake --build --preset release-drivers", + "group": "build", + "problemMatcher": ["$gcc"], + "presentation": { "reveal": "always", "panel": "shared" } + }, { "label": "CMake: Build (sanitizer)", "type": "shell", diff --git a/CMakePresets.json b/CMakePresets.json index 
a1134ab8..2a4730a7 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -41,6 +41,19 @@ "CMAKE_EXPORT_COMPILE_COMMANDS": "OFF" } }, + { + "name": "release-drivers", + "displayName": "Release (with debug drivers)", + "description": "Optimized release build that also compiles the debug driver executables", + "binaryDir": "${sourceDir}/build-release-drivers", + "cacheVariables": { + "USE_DEBUG": "OFF", + "BUILD_DEBUG_TARGETS": "ON", + "BUILD_TEST": "OFF", + "BUILD_PYTHON": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "OFF" + } + }, { "name": "sanitizer", "displayName": "Sanitizer (ASAN + UBSAN)", @@ -76,6 +89,12 @@ "configurePreset": "release", "jobs": 0 }, + { + "name": "release-drivers", + "displayName": "Release (with debug drivers)", + "configurePreset": "release-drivers", + "jobs": 0 + }, { "name": "sanitizer", "displayName": "Sanitizer", From a8b35d6c459b3e5c4d18e0eccaa9ed65e8422a0d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 14:51:44 -0400 Subject: [PATCH 12/64] Fixed probit bugs --- .vscode/launch.json | 8 ++++---- debug/bart_debug.cpp | 30 ++++++++++++++++++++++-------- include/stochtree/probit.h | 12 ++++++++---- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index b8dcf442..df75094b 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "1234"], + "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -33,7 +33,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", 
"20", "--num_mcmc", "100", "--seed", "1234"], + "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -63,7 +63,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "101"], + "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -72,7 +72,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "20", "--num_mcmc", "100", "--seed", "101"], + "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 4f17f486..80baa1c2 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -152,7 +152,14 @@ void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, // ---- Scenario 0: homoskedastic constant-leaf BART ------------------- void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { - std::mt19937 rng(seed); + int rng_seed; + if (seed == -1) { + std::random_device rd; + rng_seed = rd(); + } else { + rng_seed = seed; + } + std::mt19937 rng(rng_seed); RegressionDataset data = generate_constant_leaf_regression_data(n, p, rng); double y_bar = data.y.mean(); @@ -191,7 +198,14 @@ void run_scenario_0(int n, int p, int num_trees, int 
num_gfr, int num_mcmc, int // ---- Scenario 1: constant-leaf probit BART ------------------- void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { - std::mt19937 rng(seed); + int rng_seed; + if (seed == -1) { + std::random_device rd; + rng_seed = rd(); + } else { + rng_seed = seed; + } + std::mt19937 rng(rng_seed); ProbitDataset data = generate_probit_data(n, p, rng); double y_bar = StochTree::norm_cdf(data.y.mean()); @@ -205,7 +219,7 @@ void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int // Data augmentation: sample latent Z given current forest predictions auto post_iter = [&](StochTree::ForestTracker& tracker, double&) { StochTree::sample_probit_latent_outcome( - rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), n); + rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), y_bar, n); }; auto report = [&](const std::vector& preds, double global_variance) { @@ -232,7 +246,7 @@ int main(int argc, char** argv) { int n = 500; int p = 5; int num_trees = 200; - int num_gfr = 20; + int num_gfr = 10; int num_mcmc = 100; int seed = 1234; @@ -254,22 +268,22 @@ int main(int argc, char** argv) { num_gfr = val; else if (arg == "--num_mcmc") num_mcmc = val; - else if (arg == "--num_mcmc") + else if (arg == "--seed") seed = val; } else { std::cerr << "Unknown or incomplete argument: " << arg << "\n" << "Usage: bart_debug [--scenario N] [--n N] [--p N]" - " [--num_trees N] [--num_gfr N] [--num_mcmc N]\n"; + " [--num_trees N] [--num_gfr N] [--num_mcmc N] [--seed N]\n"; return 1; } } switch (scenario) { case 0: - run_scenario_0(n, p, num_trees, num_gfr, num_mcmc); + run_scenario_0(n, p, num_trees, num_gfr, num_mcmc, seed); break; case 1: - run_scenario_1(n, p, num_trees, num_gfr, num_mcmc); + run_scenario_1(n, p, num_trees, num_gfr, num_mcmc, seed); break; default: std::cerr << "Unknown scenario " << scenario diff --git a/include/stochtree/probit.h 
b/include/stochtree/probit.h index 79dd0558..300c25ee 100644 --- a/include/stochtree/probit.h +++ b/include/stochtree/probit.h @@ -9,20 +9,24 @@ namespace StochTree { -void sample_probit_latent_outcome(std::mt19937& gen, double* outcome, double* conditional_mean, double* latent_outcome, int n) { +void sample_probit_latent_outcome(std::mt19937& gen, double* outcome, double* conditional_mean, double* partial_residual, double y_bar, int n) { double uniform_draw_std; double uniform_draw_trunc; double quantile; + double cond_mean; + double latent_outcome; for (int i = 0; i < n; i++) { + cond_mean = conditional_mean[i] + y_bar; uniform_draw_std = standard_uniform_draw_53bit(gen); - quantile = norm_cdf(0 - conditional_mean[i]); + quantile = norm_cdf(0 - cond_mean); if (outcome[i] == 1.0) { uniform_draw_trunc = quantile + uniform_draw_std * (1.0 - quantile); - latent_outcome[i] = norm_inv_cdf(uniform_draw_trunc) + conditional_mean[i]; + latent_outcome = norm_inv_cdf(uniform_draw_trunc) + cond_mean; } else { uniform_draw_trunc = uniform_draw_std * quantile; - latent_outcome[i] = norm_inv_cdf(uniform_draw_trunc) + conditional_mean[i]; + latent_outcome = norm_inv_cdf(uniform_draw_trunc) + cond_mean; } + partial_residual[i] = latent_outcome - cond_mean; } } From 8a21de8f6d4c694e965bd87ba576478d53ad43dd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 14:57:57 -0400 Subject: [PATCH 13/64] Added comments --- debug/bart_debug.cpp | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 80baa1c2..b1059b17 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -98,29 +98,36 @@ void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, PostIterFn post_iter, ReportFn report_results) { + // Single-threaded with default cutpoint grid size (for now) constexpr int 
num_threads = 1; constexpr int cutpoint_grid_size = 100; + // Model parameters for split rule selection and tree sweeps std::vector feature_types(p, StochTree::FeatureType::kNumeric); std::vector var_weights(p, 1.0 / p); std::vector sweep_indices(num_trees); std::iota(sweep_indices.begin(), sweep_indices.end(), 0); + // Ephemeral sampler state StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); StochTree::ForestContainer forest_samples(num_trees, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); StochTree::TreeEnsemble active_forest(num_trees, 1, true, false); StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, num_trees, n); + // Initialize forest and tracker predictions to 0 (after standardization, this is the best initial guess) active_forest.SetLeafValue(0.0); UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, false, std::minus()); tracker.UpdatePredictions(&active_forest, dataset); + // Initialize leaf model and global variance for sampling iterations StochTree::GaussianConstantLeafModel leaf_model(1.0 / num_trees); double global_variance = 1.0; + // Run GFR std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; bool pre_initialized = true; for (int i = 0; i < num_gfr; i++) { + // Sample forest StochTree::GFRSampleOneIter< StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( @@ -129,11 +136,15 @@ void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, var_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); + + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) 
post_iter(tracker, global_variance); } + // Run MCMC std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; for (int i = 0; i < num_mcmc; i++) { + // Sample forest StochTree::MCMCSampleOneIter< StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( @@ -142,16 +153,19 @@ void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, var_weights, sweep_indices, global_variance, /*keep_forest=*/true, /*pre_initialized=*/true, /*backfitting=*/true, num_threads); + + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) post_iter(tracker, global_variance); } - // Posterior predictions: column-major, element [j*n + i] = sample j, obs i + // Analyze posterior predictions (column-major, element [j*n + i] = sample j, obs i) report_results(forest_samples.Predict(dataset), global_variance); } // ---- Scenario 0: homoskedastic constant-leaf BART ------------------- void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { + // Allow seed to be non-deterministic if set to sentinel value of -1 int rng_seed; if (seed == -1) { std::random_device rd; @@ -161,22 +175,27 @@ void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int } std::mt19937 rng(rng_seed); + // Generate data RegressionDataset data = generate_constant_leaf_regression_data(n, p, rng); double y_bar = data.y.mean(); double y_std = std::sqrt((data.y.array() - y_bar).square().sum() / (data.y.size() - 1)); Eigen::VectorXd resid_vec = (data.y.array() - y_bar) / y_std; // standardize + // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); StochTree::ColumnVector residual(resid_vec.data(), n); + // Initialize global error variance model constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior StochTree::GlobalHomoskedasticVarianceModel var_model; + // Lambda function for 
sampling global error variance after each forest sample auto post_iter = [&](StochTree::ForestTracker&, double& global_variance) { global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); }; + // Lambda function for reporting RMSE and last draw of global error variance model auto report = [&](const std::vector& preds, double global_variance) { double rmse_sum = 0.0; for (int i = 0; i < n; i++) { @@ -192,12 +211,14 @@ void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int << " sigma (truth): 1.0\n"; }; + // Dispatch BART sampler run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); } // ---- Scenario 1: constant-leaf probit BART ------------------- void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { + // Allow seed to be non-deterministic if set to sentinel value of -1 int rng_seed; if (seed == -1) { std::random_device rd; @@ -207,21 +228,24 @@ void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int } std::mt19937 rng(rng_seed); + // Generate data ProbitDataset data = generate_probit_data(n, p, rng); double y_bar = StochTree::norm_cdf(data.y.mean()); Eigen::VectorXd y_vec = data.y.array(); Eigen::VectorXd Z_vec = (data.y.array() - y_bar); + // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); StochTree::ColumnVector residual(Z_vec.data(), n); - // Data augmentation: sample latent Z given current forest predictions + // Lambda function for probit data augmentation sampling step (after each forest sample) auto post_iter = [&](StochTree::ForestTracker& tracker, double&) { StochTree::sample_probit_latent_outcome( rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), y_bar, n); }; + // Lambda function for reporting RMSE auto report = [&](const std::vector& preds, double 
global_variance) { double rmse_sum = 0.0; for (int i = 0; i < n; i++) { @@ -236,6 +260,7 @@ void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int << " sigma (truth): 1.0\n"; }; + // Dispatch BART sampler run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); } From 02e95117b7c32023e45f8d37dbc045b804036d93 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 15:31:14 -0400 Subject: [PATCH 14/64] Updated BCF and BART programs --- .vscode/launch.json | 23 +++- debug/bart_debug.cpp | 6 +- debug/bcf_debug.cpp | 290 +++++++++++++++++++++++++++---------------- 3 files changed, 210 insertions(+), 109 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index df75094b..814cf2dc 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -15,7 +15,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["0"], + "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -43,7 +43,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["0"], + "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -77,6 +77,25 @@ "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" }, + { + "name": "bcf_debug (macOS, Release)", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build-release-drivers/bcf_debug", + "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "cwd": 
"${workspaceFolder}", + "preLaunchTask": "CMake: Build (release-drivers)" + }, + { + "name": "bcf_debug (Linux/Container, Release)", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/build-release-drivers/bcf_debug", + "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "cwd": "${workspaceFolder}", + "MIMode": "gdb", + "preLaunchTask": "CMake: Build (release-drivers)" + }, { "name": "Python Debugger: Current File", "type": "debugpy", diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index b1059b17..91990e37 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -1,7 +1,9 @@ /* - * BART debug driver. The first CLI argument selects the scenario (default: 0). + * BART debug program. The first CLI argument selects the scenario (default: 0). + * + * Usage: bart_debug [--scenario N] [--n N] [--p N] [--num_trees N] + * [--num_gfr N] [--num_mcmc N] [--seed N] * - * Usage: bart_debug [scenario] * 0 Homoskedastic constant-leaf BART * DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + eps, eps ~ N(0,1) * diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 6911097d..52d1ad60 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -1,7 +1,9 @@ /* - * BCF debug driver. The first CLI argument selects the scenario (default: 0). + * BCF debug program + * + * Usage: bcf_debug [--scenario N] [--n N] [--p N] [--num_trees_mu N] [--num_trees_tau N] + * [--num_gfr N] [--num_mcmc N] [--seed N] * - * Usage: bcf_debug [scenario] * 0 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 * tau(x) = 1 + x3 @@ -10,18 +12,6 @@ * * Add scenarios here as the BCFSampler API develops (heteroskedastic, * random effects, propensity weighting, etc.). - * - * Algorithm overview - * ------------------ - * Both forests share a single ColumnVector residual. 
Alternating GFR/MCMC - * steps for mu and tau each run backfitting, so the residual after each - * step correctly reflects the other forest's current contribution: - * - * After mu step: residual ≈ y - y_bar - mu_hat - * After tau step: residual ≈ y - y_bar - mu_hat - tau_hat*z - * - * The tau forest uses z as a univariate basis (AddBasis), so its prediction - * for observation i is tau_leaf(i) * z(i), and backfitting is z-aware. */ #include @@ -35,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -76,156 +67,245 @@ BCFDataset generate_data(int n, int p, std::mt19937& rng) { return d; } -// ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- - -void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc) { +// ---- Shared sampler loop -------------------------------------------- +// +// Runs alternating mu/tau GFR warmup then MCMC, sharing a single residual. +// The two scenario-specific hooks are: +// +// post_iter(mu_tracker, global_variance) — called after each full mu+tau +// iteration (e.g. sample global variance). +// +// report_results(mu_preds, tau_preds, global_variance) — called once after +// all samples are collected; receives column-major prediction matrices +// and the final global variance value. 
+ +using PostIterFn = std::function; +using BCFReportFn = std::function&, const std::vector&, double)>; + +void run_bcf_sampler(int n, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, + StochTree::ForestDataset& dataset, + StochTree::ColumnVector& residual, std::mt19937& rng, + PostIterFn post_iter, BCFReportFn report_results) { + // Single-threaded with default cutpoint grid size (for now) constexpr int num_threads = 1; constexpr int cutpoint_grid_size = 100; - std::mt19937 rng(42); - - BCFDataset data = generate_data(n, p, rng); - double y_bar = data.y.mean(); - Eigen::VectorXd resid_vec = data.y.array() - y_bar; - - // Mu dataset: X covariates only - StochTree::ForestDataset dataset_mu; - dataset_mu.AddCovariates(data.X.data(), n, p, /*row_major=*/true); - - // Tau dataset: X covariates + Z as univariate basis - StochTree::ForestDataset dataset_tau; - dataset_tau.AddCovariates(data.X.data(), n, p, true); - dataset_tau.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); - - // Shared residual - StochTree::ColumnVector residual(resid_vec.data(), n); + // Model parameters for split rule selection and tree sweeps std::vector feature_types(p, StochTree::FeatureType::kNumeric); std::vector var_weights(p, 1.0 / p); - std::vector sweep_indices; - - StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); + std::vector sweep_indices_mu(num_trees_mu); + std::iota(sweep_indices_mu.begin(), sweep_indices_mu.end(), 0); + std::vector sweep_indices_tau(num_trees_tau); + std::iota(sweep_indices_tau.begin(), sweep_indices_tau.end(), 0); + // Ephemeral sampler state // Mu forest: constant-leaf - StochTree::ForestContainer mu_samples(num_trees, 1, /*leaf_constant=*/true, /*exponentiated=*/false); - StochTree::TreeEnsemble mu_forest(num_trees, 1, true, false); - StochTree::ForestTracker mu_tracker(dataset_mu.GetCovariates(), feature_types, num_trees, n); - double mu_leaf_scale = 1.0 / num_trees; - StochTree::GaussianConstantLeafModel 
mu_leaf_model(mu_leaf_scale); + StochTree::TreePrior mu_tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); + StochTree::ForestContainer mu_samples(num_trees_mu, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); + StochTree::TreeEnsemble mu_forest(num_trees_mu, 1, true, false); + StochTree::ForestTracker mu_tracker(dataset.GetCovariates(), feature_types, num_trees_mu, n); + StochTree::GaussianConstantLeafModel mu_leaf_model(1.0 / num_trees_mu); // Tau forest: univariate regression leaf (prediction = leaf_param * z) - StochTree::ForestContainer tau_samples(num_trees, 1, /*leaf_constant=*/false, /*exponentiated=*/false); - StochTree::TreeEnsemble tau_forest(num_trees, 1, false, false); - StochTree::ForestTracker tau_tracker(dataset_tau.GetCovariates(), feature_types, num_trees, n); - double tau_leaf_scale = 1.0 / num_trees; - StochTree::GaussianUnivariateRegressionLeafModel tau_leaf_model(tau_leaf_scale); - + StochTree::TreePrior tau_tree_prior(0.5, 2.0, /*min_samples_leaf=*/5); + StochTree::ForestContainer tau_samples(num_trees_tau, /*output_dim=*/1, /*leaf_constant=*/false, /*exponentiated=*/false); + StochTree::TreeEnsemble tau_forest(num_trees_tau, 1, false, false); + StochTree::ForestTracker tau_tracker(dataset.GetCovariates(), feature_types, num_trees_tau, n); + StochTree::GaussianUnivariateRegressionLeafModel tau_leaf_model(1.0 / num_trees_tau); + + // Initialize mu forest and tracker predictions to 0 + mu_forest.SetLeafValue(0.0); + UpdateResidualEntireForest(mu_tracker, dataset, residual, &mu_forest, false, std::minus()); + mu_tracker.UpdatePredictions(&mu_forest, dataset); + + // Initial tau forest and tracker predictions to 0 + tau_forest.SetLeafValue(0.0); + UpdateResidualEntireForest(tau_tracker, dataset, residual, &tau_forest, false, std::minus()); + tau_tracker.UpdatePredictions(&tau_forest, dataset); + + // Initialize global error variance to 1 (output is standardized) double global_variance = 1.0; - constexpr double a_sigma = 0.0, 
b_sigma = 0.0; // non-informative IG prior - StochTree::GlobalHomoskedasticVarianceModel var_model; - // GFR warmup — no samples stored + // Run GFR std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; - bool pre_mu = false, pre_tau = false; + bool pre_initialized = true; for (int i = 0; i < num_gfr; i++) { + // Sample mu forest StochTree::GFRSampleOneIter< StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( mu_forest, mu_tracker, mu_samples, mu_leaf_model, - dataset_mu, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, feature_types, - cutpoint_grid_size, /*keep_forest=*/false, pre_mu, - /*backfitting=*/true, /*num_features_subsample=*/-1, num_threads); - pre_mu = true; + dataset, residual, mu_tree_prior, rng, + var_weights, sweep_indices_mu, global_variance, feature_types, + cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, + /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); + // Sample tau forest StochTree::GFRSampleOneIter< StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>( tau_forest, tau_tracker, tau_samples, tau_leaf_model, - dataset_tau, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, feature_types, - cutpoint_grid_size, false, pre_tau, - true, -1, num_threads); - pre_tau = true; - - global_variance = var_model.SampleVarianceParameter( - residual.GetData(), a_sigma, b_sigma, rng); + dataset, residual, tau_tree_prior, rng, + var_weights, sweep_indices_tau, global_variance, feature_types, + cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, + /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); + + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) 
+ post_iter(mu_tracker, global_variance); } - // MCMC — store samples + // Run MCMC std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; for (int i = 0; i < num_mcmc; i++) { + // Sample mu forest StochTree::MCMCSampleOneIter< StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( mu_forest, mu_tracker, mu_samples, mu_leaf_model, - dataset_mu, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, + dataset, residual, mu_tree_prior, rng, + var_weights, sweep_indices_mu, global_variance, /*keep_forest=*/true, /*pre_initialized=*/true, /*backfitting=*/true, num_threads); + // Sample tau forest StochTree::MCMCSampleOneIter< StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>( tau_forest, tau_tracker, tau_samples, tau_leaf_model, - dataset_tau, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, - true, true, true, num_threads); + dataset, residual, tau_tree_prior, rng, + var_weights, sweep_indices_tau, global_variance, + /*keep_forest=*/true, /*pre_initialized=*/true, + /*backfitting=*/true, num_threads); - global_variance = var_model.SampleVarianceParameter( - residual.GetData(), a_sigma, b_sigma, rng); + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) 
+ post_iter(mu_tracker, global_variance); } - // Posterior predictions - // mu_preds[j*n + i] = mu_hat for sample j, obs i (column-major) - // tau_preds[j*n + i] = tau_hat(i)*z(i) (since basis is z) - std::vector mu_preds = mu_samples.Predict(dataset_mu); - std::vector tau_preds = tau_samples.Predict(dataset_tau); + // Analyze posterior predictions (column-major, element [j*n + i] = sample j, obs i) + report_results(mu_samples.Predict(dataset), tau_samples.Predict(dataset), global_variance); +} - double mu_rmse_sum = 0.0; - double tau_rmse_sum = 0.0; - int n_treated = 0; +// ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- - for (int i = 0; i < n; i++) { - double mu_hat = y_bar; - for (int j = 0; j < num_mcmc; j++) - mu_hat += mu_preds[static_cast(j * n + i)] / num_mcmc; - double mu_err = mu_hat - data.mu_true(i); - mu_rmse_sum += mu_err * mu_err; - - // For z=1: tau_preds = tau_hat * 1 = tau_hat, so we can evaluate CATE - if (data.z(i) > 0.5) { +void run_scenario_0(int n, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { + // Allow seed to be non-deterministic if set to sentinel value of -1 + int rng_seed; + if (seed == -1) { + std::random_device rd; + rng_seed = rd(); + } else { + rng_seed = seed; + } + std::mt19937 rng(rng_seed); + + // Generate data and standardize outcome + BCFDataset data = generate_data(n, p, rng); + double y_bar = data.y.mean(); + double y_std = std::sqrt((data.y.array() - y_bar).square().mean()); + Eigen::VectorXd resid_vec = (data.y.array() - y_bar) / y_std; // standardize + + // Shared dataset: only tau forest uses the Z basis for leaf regression + StochTree::ForestDataset dataset; + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + dataset.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); + + // Shared residual + StochTree::ColumnVector residual(resid_vec.data(), n); + + // Global error variance model + constexpr double a_sigma = 0.0, b_sigma = 0.0; 
// non-informative IG prior + StochTree::GlobalHomoskedasticVarianceModel var_model; + + // Lambda function for sampling global error variance after each mu+tau step + auto post_iter = [&](StochTree::ForestTracker&, double& global_variance) { + global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); + }; + + // Lambda function for reporting mu/tau RMSE and last draw of global error variance + auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, + double global_variance) { + double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; + int n_treated = 0; + + for (int i = 0; i < n; i++) { + double y_hat = 0.0; + double mu_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + mu_hat += mu_preds[static_cast(j * n + i)] / num_mcmc; + mu_rmse_sum += (mu_hat * y_std + y_bar - data.mu_true(i)) * (mu_hat * y_std + y_bar - data.mu_true(i)); + y_hat += mu_hat * y_std + y_bar; + + // For z=1: tau_preds = tau_hat * 1 = tau_hat, so we can evaluate CATE directly double tau_hat = 0.0; for (int j = 0; j < num_mcmc; j++) tau_hat += tau_preds[static_cast(j * n + i)] / num_mcmc; - double tau_err = tau_hat - data.tau_true(i); - tau_rmse_sum += tau_err * tau_err; - n_treated++; + tau_rmse_sum += (tau_hat * y_std - data.tau_true(i)) * (tau_hat * y_std - data.tau_true(i)); + y_hat += tau_hat * data.z(i) * y_std; + y_rmse_sum += (y_hat - data.y(i)) * (y_hat - data.y(i)); } - } - std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" - << " mu RMSE: " << std::sqrt(mu_rmse_sum / n) << "\n" - << " tau RMSE (treated): " - << (n_treated > 0 ? 
std::sqrt(tau_rmse_sum / n_treated) : 0.0) << "\n" - << " sigma (last sample): " << std::sqrt(global_variance) << "\n" - << " sigma (truth): 0.5\n"; + std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" + << " mu RMSE: " << std::sqrt(mu_rmse_sum / n) << "\n" + << " tau RMSE (treated): " << std::sqrt(tau_rmse_sum / n) << "\n" + << " y RMSE: " << std::sqrt(y_rmse_sum / n) << "\n" + << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" + << " sigma (truth): 0.5\n"; + }; + + // Dispatch BCF sampler + run_bcf_sampler(n, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, dataset, residual, rng, + post_iter, report); } // ---- Main ----------------------------------------------------------- int main(int argc, char** argv) { int scenario = 0; - if (argc > 1) scenario = std::stoi(argv[1]); - - constexpr int n = 200, p = 5, num_trees = 200, num_gfr = 20, num_mcmc = 100; + int n = 500; + int p = 5; + int num_trees_mu = 200; + int num_trees_tau = 50; + int num_gfr = 20; + int num_mcmc = 100; + int seed = 1234; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if ((arg == "--scenario" || arg == "--n" || arg == "--p" || + arg == "--num_trees_mu" || arg == "--num_trees_tau" || arg == "--num_gfr" || arg == "--num_mcmc" || arg == "--seed") && + i + 1 < argc) { + int val = std::stoi(argv[++i]); + if (arg == "--scenario") + scenario = val; + else if (arg == "--n") + n = val; + else if (arg == "--p") + p = val; + else if (arg == "--num_trees_mu") + num_trees_mu = val; + else if (arg == "--num_trees_tau") + num_trees_tau = val; + else if (arg == "--num_gfr") + num_gfr = val; + else if (arg == "--num_mcmc") + num_mcmc = val; + else if (arg == "--seed") + seed = val; + } else { + std::cerr << "Unknown or incomplete argument: " << arg << "\n" + << "Usage: bcf_debug [--scenario N] [--n N] [--p N]" + " [--num_trees_mu N] [--num_trees_tau N] [--num_gfr N] [--num_mcmc N] [--seed N]\n"; + return 1; + } + } switch (scenario) { 
case 0: - run_scenario_0(n, p, num_trees, num_gfr, num_mcmc); + run_scenario_0(n, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); break; default: std::cerr << "Unknown scenario " << scenario - << ". Available scenarios: 0 (BasicBCF)\n"; + << ". Available scenarios: 0 (BCF: constant mu + univariate tau)\n"; return 1; } return 0; From d3d598ccc7275ad6414abfbdf29229b262f236a8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 15:44:12 -0400 Subject: [PATCH 15/64] Updated BART and BCF debug programs to use a test set for evaluations --- .vscode/launch.json | 16 +++++------ debug/bart_debug.cpp | 62 ++++++++++++++++++++++++++-------------- debug/bcf_debug.cpp | 68 ++++++++++++++++++++++++++++---------------- 3 files changed, 91 insertions(+), 55 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 814cf2dc..8ae6485b 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -15,7 +15,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -33,7 +33,7 @@ "type": "cppdbg", "request": "launch", "program": 
"${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -43,7 +43,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -63,7 +63,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -72,7 +72,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" @@ -82,7 +82,7 @@ "type": "lldb", 
"request": "launch", "program": "${workspaceFolder}/build-release-drivers/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -91,7 +91,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 91990e37..5f3b0bec 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -1,12 +1,16 @@ /* * BART debug program. The first CLI argument selects the scenario (default: 0). * - * Usage: bart_debug [--scenario N] [--n N] [--p N] [--num_trees N] + * Usage: bart_debug [--scenario N] [--n N] [--n_test N] [--p N] [--num_trees N] * [--num_gfr N] [--num_mcmc N] [--seed N] * * 0 Homoskedastic constant-leaf BART * DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + eps, eps ~ N(0,1) * + * 1 Homoskedastic constant-leaf probit BART + * DGP: Z = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + eps, eps ~ N(0,1) + * y = 1{Z > 0} + * * Add scenarios here as the BARTSampler API develops (heteroskedastic, * random effects, multivariate leaf, etc.). 
*/ @@ -96,9 +100,10 @@ ProbitDataset generate_probit_data(int n, int p, std::mt19937& rng) { using PostIterFn = std::function; using ReportFn = std::function&, double)>; -void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, +void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, + StochTree::ForestDataset& test_dataset, PostIterFn post_iter, ReportFn report_results) { // Single-threaded with default cutpoint grid size (for now) constexpr int num_threads = 1; @@ -160,13 +165,13 @@ void run_bart_sampler(int n, int p, int num_trees, int num_gfr, int num_mcmc, post_iter(tracker, global_variance); } - // Analyze posterior predictions (column-major, element [j*n + i] = sample j, obs i) - report_results(forest_samples.Predict(dataset), global_variance); + // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) + report_results(forest_samples.Predict(test_dataset), global_variance); } // ---- Scenario 0: homoskedastic constant-leaf BART ------------------- -void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { +void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { // Allow seed to be non-deterministic if set to sentinel value of -1 int rng_seed; if (seed == -1) { @@ -197,29 +202,34 @@ void run_scenario_0(int n, int p, int num_trees, int num_gfr, int num_mcmc, int global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); }; - // Lambda function for reporting RMSE and last draw of global error variance model + // Generate test data and build test dataset + RegressionDataset test_data = generate_constant_leaf_regression_data(n_test, p, rng); + StochTree::ForestDataset test_dataset; + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + + 
// Lambda function for reporting test-set RMSE and last draw of global error variance model auto report = [&](const std::vector& preds, double global_variance) { double rmse_sum = 0.0; - for (int i = 0; i < n; i++) { + for (int i = 0; i < n_test; i++) { double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) - mu_hat += preds[static_cast(j * n + i)] / num_mcmc; - double err = (mu_hat * y_std + y_bar) - data.y(i); + mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; + double err = (mu_hat * y_std + y_bar) - test_data.y(i); rmse_sum += err * err; } std::cout << "\nScenario 0 (Homoskedastic BART):\n" - << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" + << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n" << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" << " sigma (truth): 1.0\n"; }; // Dispatch BART sampler - run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); + run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, test_dataset, post_iter, report); } // ---- Scenario 1: constant-leaf probit BART ------------------- -void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { +void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { // Allow seed to be non-deterministic if set to sentinel value of -1 int rng_seed; if (seed == -1) { @@ -247,23 +257,28 @@ void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), y_bar, n); }; - // Lambda function for reporting RMSE + // Generate test data and build test dataset + ProbitDataset test_data = generate_probit_data(n_test, p, rng); + StochTree::ForestDataset test_dataset; + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + + // Lambda function for reporting test-set RMSE auto report = [&](const std::vector& 
preds, double global_variance) { double rmse_sum = 0.0; - for (int i = 0; i < n; i++) { + for (int i = 0; i < n_test; i++) { double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) - mu_hat += preds[static_cast(j * n + i)] / num_mcmc; - double err = (mu_hat + y_bar) - data.Z(i); + mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; + double err = (mu_hat + y_bar) - test_data.Z(i); rmse_sum += err * err; } std::cout << "\nScenario 1 (Probit BART):\n" - << " RMSE: " << std::sqrt(rmse_sum / n) << "\n" + << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n" << " sigma (truth): 1.0\n"; }; // Dispatch BART sampler - run_bart_sampler(n, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, post_iter, report); + run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, test_dataset, post_iter, report); } // ---- Main ----------------------------------------------------------- @@ -271,6 +286,7 @@ void run_scenario_1(int n, int p, int num_trees, int num_gfr, int num_mcmc, int int main(int argc, char** argv) { int scenario = 1; int n = 500; + int n_test = 100; int p = 5; int num_trees = 200; int num_gfr = 10; @@ -279,7 +295,7 @@ int main(int argc, char** argv) { for (int i = 1; i < argc; ++i) { std::string arg = argv[i]; - if ((arg == "--scenario" || arg == "--n" || arg == "--p" || + if ((arg == "--scenario" || arg == "--n" || arg == "--n_test" || arg == "--p" || arg == "--num_trees" || arg == "--num_gfr" || arg == "--num_mcmc" || arg == "--seed") && i + 1 < argc) { int val = std::stoi(argv[++i]); @@ -287,6 +303,8 @@ int main(int argc, char** argv) { scenario = val; else if (arg == "--n") n = val; + else if (arg == "--n_test") + n_test = val; else if (arg == "--p") p = val; else if (arg == "--num_trees") @@ -299,7 +317,7 @@ int main(int argc, char** argv) { seed = val; } else { std::cerr << "Unknown or incomplete argument: " << arg << "\n" - << "Usage: bart_debug [--scenario N] [--n N] [--p N]" + << "Usage: bart_debug [--scenario N] 
[--n N] [--n_test N] [--p N]" " [--num_trees N] [--num_gfr N] [--num_mcmc N] [--seed N]\n"; return 1; } @@ -307,10 +325,10 @@ int main(int argc, char** argv) { switch (scenario) { case 0: - run_scenario_0(n, p, num_trees, num_gfr, num_mcmc, seed); + run_scenario_0(n, n_test, p, num_trees, num_gfr, num_mcmc, seed); break; case 1: - run_scenario_1(n, p, num_trees, num_gfr, num_mcmc, seed); + run_scenario_1(n, n_test, p, num_trees, num_gfr, num_mcmc, seed); break; default: std::cerr << "Unknown scenario " << scenario diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 52d1ad60..b6134f36 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -1,7 +1,7 @@ /* * BCF debug program * - * Usage: bcf_debug [--scenario N] [--n N] [--p N] [--num_trees_mu N] [--num_trees_tau N] + * Usage: bcf_debug [--scenario N] [--n N] [--n_test N] [--p N] [--num_trees_mu N] [--num_trees_tau N] * [--num_gfr N] [--num_mcmc N] [--seed N] * * 0 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) @@ -82,9 +82,11 @@ BCFDataset generate_data(int n, int p, std::mt19937& rng) { using PostIterFn = std::function; using BCFReportFn = std::function&, const std::vector&, double)>; -void run_bcf_sampler(int n, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, +void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, + StochTree::ForestDataset& test_dataset, + StochTree::ForestDataset& test_dataset_cate, PostIterFn post_iter, BCFReportFn report_results) { // Single-threaded with default cutpoint grid size (for now) constexpr int num_threads = 1; @@ -181,13 +183,14 @@ void run_bcf_sampler(int n, int p, int num_trees_mu, int num_trees_tau, int num_ post_iter(mu_tracker, global_variance); } - // Analyze posterior predictions (column-major, element [j*n + i] = sample j, obs i) - 
report_results(mu_samples.Predict(dataset), tau_samples.Predict(dataset), global_variance); + // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) + // tau uses test_dataset_cate (z=1 basis) so predictions == raw CATE estimates + report_results(mu_samples.Predict(test_dataset), tau_samples.Predict(test_dataset_cate), global_variance); } // ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- -void run_scenario_0(int n, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { +void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { // Allow seed to be non-deterministic if set to sentinel value of -1 int rng_seed; if (seed == -1) { @@ -221,40 +224,52 @@ void run_scenario_0(int n, int p, int num_trees_mu, int num_trees_tau, int num_g global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); }; + // Generate test data and build test datasets + BCFDataset test_data = generate_data(n_test, p, rng); + Eigen::VectorXd z_ones = Eigen::VectorXd::Ones(n_test); + + // Test dataset: covariates + actual treatment z (for y prediction) + StochTree::ForestDataset test_dataset; + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, /*row_major=*/false); + + // CATE dataset: covariates + z=1 so tau predictions == raw CATE estimates + StochTree::ForestDataset test_dataset_cate; + test_dataset_cate.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + test_dataset_cate.AddBasis(z_ones.data(), n_test, /*num_col=*/1, /*row_major=*/false); + // Lambda function for reporting mu/tau RMSE and last draw of global error variance auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, double global_variance) { double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; - int 
n_treated = 0; - for (int i = 0; i < n; i++) { - double y_hat = 0.0; + for (int i = 0; i < n_test; i++) { double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) - mu_hat += mu_preds[static_cast(j * n + i)] / num_mcmc; - mu_rmse_sum += (mu_hat * y_std + y_bar - data.mu_true(i)) * (mu_hat * y_std + y_bar - data.mu_true(i)); - y_hat += mu_hat * y_std + y_bar; + mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; + mu_rmse_sum += (mu_hat * y_std + y_bar - test_data.mu_true(i)) * (mu_hat * y_std + y_bar - test_data.mu_true(i)); - // For z=1: tau_preds = tau_hat * 1 = tau_hat, so we can evaluate CATE directly - double tau_hat = 0.0; + // tau_preds from test_dataset_cate (z=1 basis) => raw CATE estimates + double cate_hat = 0.0; for (int j = 0; j < num_mcmc; j++) - tau_hat += tau_preds[static_cast(j * n + i)] / num_mcmc; - tau_rmse_sum += (tau_hat * y_std - data.tau_true(i)) * (tau_hat * y_std - data.tau_true(i)); - y_hat += tau_hat * data.z(i) * y_std; - y_rmse_sum += (y_hat - data.y(i)) * (y_hat - data.y(i)); + cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; + tau_rmse_sum += (cate_hat * y_std - test_data.tau_true(i)) * (cate_hat * y_std - test_data.tau_true(i)); + + double y_hat = mu_hat * y_std + y_bar + cate_hat * test_data.z(i) * y_std; + y_rmse_sum += (y_hat - test_data.y(i)) * (y_hat - test_data.y(i)); } std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" - << " mu RMSE: " << std::sqrt(mu_rmse_sum / n) << "\n" - << " tau RMSE (treated): " << std::sqrt(tau_rmse_sum / n) << "\n" - << " y RMSE: " << std::sqrt(y_rmse_sum / n) << "\n" + << " mu RMSE (test): " << std::sqrt(mu_rmse_sum / n_test) << "\n" + << " tau RMSE (test): " << std::sqrt(tau_rmse_sum / n_test) << "\n" + << " y RMSE (test): " << std::sqrt(y_rmse_sum / n_test) << "\n" << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" << " sigma (truth): 0.5\n"; }; // Dispatch BCF sampler - run_bcf_sampler(n, p, num_trees_mu, num_trees_tau, 
num_gfr, num_mcmc, dataset, residual, rng, - post_iter, report); + run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, dataset, residual, rng, + test_dataset, test_dataset_cate, post_iter, report); } // ---- Main ----------------------------------------------------------- @@ -262,6 +277,7 @@ void run_scenario_0(int n, int p, int num_trees_mu, int num_trees_tau, int num_g int main(int argc, char** argv) { int scenario = 0; int n = 500; + int n_test = 100; int p = 5; int num_trees_mu = 200; int num_trees_tau = 50; @@ -271,7 +287,7 @@ int main(int argc, char** argv) { for (int i = 1; i < argc; ++i) { std::string arg = argv[i]; - if ((arg == "--scenario" || arg == "--n" || arg == "--p" || + if ((arg == "--scenario" || arg == "--n" || arg == "--n_test" || arg == "--p" || arg == "--num_trees_mu" || arg == "--num_trees_tau" || arg == "--num_gfr" || arg == "--num_mcmc" || arg == "--seed") && i + 1 < argc) { int val = std::stoi(argv[++i]); @@ -279,6 +295,8 @@ int main(int argc, char** argv) { scenario = val; else if (arg == "--n") n = val; + else if (arg == "--n_test") + n_test = val; else if (arg == "--p") p = val; else if (arg == "--num_trees_mu") @@ -293,7 +311,7 @@ int main(int argc, char** argv) { seed = val; } else { std::cerr << "Unknown or incomplete argument: " << arg << "\n" - << "Usage: bcf_debug [--scenario N] [--n N] [--p N]" + << "Usage: bcf_debug [--scenario N] [--n N] [--n_test N] [--p N]" " [--num_trees_mu N] [--num_trees_tau N] [--num_gfr N] [--num_mcmc N] [--seed N]\n"; return 1; } @@ -301,7 +319,7 @@ int main(int argc, char** argv) { switch (scenario) { case 0: - run_scenario_0(n, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); + run_scenario_0(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); break; default: std::cerr << "Unknown scenario " << scenario From 4fa70bbd3a56cbe2a794653cb65a88f7d9b028a2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 16:06:19 -0400 Subject: [PATCH 16/64] 
Added probit BCF --- debug/bcf_debug.cpp | 171 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 151 insertions(+), 20 deletions(-) diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index b6134f36..affcf66c 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -10,6 +10,13 @@ * z ~ Bernoulli(0.5) * y = mu(x) + tau(x)*z + N(0, 0.5^2) * + * 1 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) + * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 + * tau(x) = 1 + x3 + * z ~ Bernoulli(0.5) + * W = mu(x) + tau(x)*z + N(0, 1) + * y = 1{W > 0} + * * Add scenarios here as the BCFSampler API develops (heteroskedastic, * random effects, propensity weighting, etc.). */ @@ -19,6 +26,7 @@ #include #include #include +#include #include #include @@ -34,7 +42,7 @@ static constexpr double kPi = 3.14159265358979323846; // ---- Data ------------------------------------------------------------ -struct BCFDataset { +struct SimpleBCFDataset { Eigen::Matrix X; Eigen::VectorXd y; Eigen::VectorXd z; @@ -42,12 +50,21 @@ struct BCFDataset { Eigen::VectorXd tau_true; }; -BCFDataset generate_data(int n, int p, std::mt19937& rng) { +struct ProbitBCFDataset { + Eigen::Matrix X; + Eigen::VectorXd y; + Eigen::VectorXd latent_outcome; + Eigen::VectorXd z; + Eigen::VectorXd mu_true; + Eigen::VectorXd tau_true; +}; + +SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); std::bernoulli_distribution bern(0.5); - BCFDataset d; + SimpleBCFDataset d; d.X.resize(n, p); d.y.resize(n); d.z.resize(n); @@ -67,6 +84,33 @@ BCFDataset generate_data(int n, int p, std::mt19937& rng) { return d; } +ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { + std::uniform_real_distribution unif(0.0, 1.0); + std::normal_distribution normal(0.0, 1.0); + std::bernoulli_distribution bern(0.5); + + ProbitBCFDataset d; + d.X.resize(n, p); + d.y.resize(n); + 
d.z.resize(n); + d.mu_true.resize(n); + d.tau_true.resize(n); + d.latent_outcome.resize(n); + + for (int i = 0; i < n; i++) + for (int j = 0; j < p; j++) + d.X(i, j) = unif(rng); + + for (int i = 0; i < n; i++) { + d.z(i) = bern(rng) ? 1.0 : 0.0; + d.mu_true(i) = 2.0 * std::sin(kPi * d.X(i, 0)) + 0.5 * d.X(i, 1); + d.tau_true(i) = 1.0 + d.X(i, 2); + d.latent_outcome(i) = d.mu_true(i) + d.tau_true(i) * d.z(i) + normal(rng); + d.y(i) = (d.latent_outcome(i) > 0.0) ? 1.0 : 0.0; + } + return d; +} + // ---- Shared sampler loop -------------------------------------------- // // Runs alternating mu/tau GFR warmup then MCMC, sharing a single residual. @@ -79,14 +123,13 @@ BCFDataset generate_data(int n, int p, std::mt19937& rng) { // all samples are collected; receives column-major prediction matrices // and the final global variance value. -using PostIterFn = std::function; +using PostIterFn = std::function; using BCFReportFn = std::function&, const std::vector&, double)>; void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, StochTree::ForestDataset& test_dataset, - StochTree::ForestDataset& test_dataset_cate, PostIterFn post_iter, BCFReportFn report_results) { // Single-threaded with default cutpoint grid size (for now) constexpr int num_threads = 1; @@ -125,6 +168,9 @@ void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_t UpdateResidualEntireForest(tau_tracker, dataset, residual, &tau_forest, false, std::minus()); tau_tracker.UpdatePredictions(&tau_forest, dataset); + // Model predictions + std::vector outcome_preds(n, 0.0); + // Initialize global error variance to 1 (output is standardized) double global_variance = 1.0; @@ -152,8 +198,13 @@ void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_t cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, 
/*backfitting=*/true, /*num_features_subsample=*/p, num_threads); + // Update predictions and residual for post-iteration hook (e.g. global variance sampling, probit data augmentation, etc.) + for (int j = 0; j < n; j++) { + outcome_preds[j] = mu_tracker.GetSamplePrediction(j) + tau_tracker.GetSamplePrediction(j); + } + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) - post_iter(mu_tracker, global_variance); + post_iter(outcome_preds.data(), global_variance); } // Run MCMC @@ -179,13 +230,17 @@ void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_t /*keep_forest=*/true, /*pre_initialized=*/true, /*backfitting=*/true, num_threads); + // Update predictions and residual for post-iteration hook (e.g. global variance sampling, probit data augmentation, etc.) + for (int j = 0; j < n; j++) { + outcome_preds[j] = mu_tracker.GetSamplePrediction(j) + tau_tracker.GetSamplePrediction(j); + } + // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) 
- post_iter(mu_tracker, global_variance); + post_iter(outcome_preds.data(), global_variance); } // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) - // tau uses test_dataset_cate (z=1 basis) so predictions == raw CATE estimates - report_results(mu_samples.Predict(test_dataset), tau_samples.Predict(test_dataset_cate), global_variance); + report_results(mu_samples.Predict(test_dataset), tau_samples.PredictRaw(test_dataset), global_variance); } // ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- @@ -202,7 +257,7 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta std::mt19937 rng(rng_seed); // Generate data and standardize outcome - BCFDataset data = generate_data(n, p, rng); + SimpleBCFDataset data = generate_simple_bcf_data(n, p, rng); double y_bar = data.y.mean(); double y_std = std::sqrt((data.y.array() - y_bar).square().mean()); Eigen::VectorXd resid_vec = (data.y.array() - y_bar) / y_std; // standardize @@ -220,24 +275,18 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta StochTree::GlobalHomoskedasticVarianceModel var_model; // Lambda function for sampling global error variance after each mu+tau step - auto post_iter = [&](StochTree::ForestTracker&, double& global_variance) { + auto post_iter = [&](double* outcome_preds, double& global_variance) { global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); }; // Generate test data and build test datasets - BCFDataset test_data = generate_data(n_test, p, rng); - Eigen::VectorXd z_ones = Eigen::VectorXd::Ones(n_test); + SimpleBCFDataset test_data = generate_simple_bcf_data(n_test, p, rng); // Test dataset: covariates + actual treatment z (for y prediction) StochTree::ForestDataset test_dataset; test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, 
/*row_major=*/false); - // CATE dataset: covariates + z=1 so tau predictions == raw CATE estimates - StochTree::ForestDataset test_dataset_cate; - test_dataset_cate.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); - test_dataset_cate.AddBasis(z_ones.data(), n_test, /*num_col=*/1, /*row_major=*/false); - // Lambda function for reporting mu/tau RMSE and last draw of global error variance auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, double global_variance) { @@ -268,8 +317,87 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta }; // Dispatch BCF sampler - run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, dataset, residual, rng, - test_dataset, test_dataset_cate, post_iter, report); + run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, + dataset, residual, rng, test_dataset, post_iter, report); +} + +// ---- Scenario 1: constant-leaf mu + univariate-leaf tau (Z basis) with probit link --- + +void run_scenario_1(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { + // Allow seed to be non-deterministic if set to sentinel value of -1 + int rng_seed; + if (seed == -1) { + std::random_device rd; + rng_seed = rd(); + } else { + rng_seed = seed; + } + std::mt19937 rng(rng_seed); + + // Generate data and standardize outcome + ProbitBCFDataset data = generate_probit_bcf_data(n, p, rng); + double y_bar = StochTree::norm_cdf(data.y.mean()); + Eigen::VectorXd y_vec = data.y.array(); + Eigen::VectorXd Z_vec = (data.y.array() - y_bar); + + // Shared dataset: only tau forest uses the Z basis for leaf regression + StochTree::ForestDataset dataset; + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + dataset.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); + + // Shared residual + StochTree::ColumnVector residual(Z_vec.data(), n); + + // Global error variance model + 
constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior + StochTree::GlobalHomoskedasticVarianceModel var_model; + + // Lambda function for probit data augmentation sampling step (after each forest sample) + auto post_iter = [&](double* outcome_preds, double&) { + StochTree::sample_probit_latent_outcome( + rng, y_vec.data(), outcome_preds, residual.GetData().data(), y_bar, n); + }; + + // Generate test data and build test datasets + ProbitBCFDataset test_data = generate_probit_bcf_data(n_test, p, rng); + + // Test dataset: covariates + actual treatment z (for y prediction) + StochTree::ForestDataset test_dataset; + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, /*row_major=*/false); + + // Lambda function for reporting mu/tau RMSE and last draw of global error variance + auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, + double global_variance) { + double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; + + for (int i = 0; i < n_test; i++) { + double mu_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; + mu_rmse_sum += (mu_hat + y_bar - test_data.mu_true(i)) * (mu_hat + y_bar - test_data.mu_true(i)); + + // tau_preds from test_dataset_cate (z=1 basis) => raw CATE estimates + double cate_hat = 0.0; + for (int j = 0; j < num_mcmc; j++) + cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; + tau_rmse_sum += (cate_hat - test_data.tau_true(i)) * (cate_hat - test_data.tau_true(i)); + + double y_hat = mu_hat + y_bar + cate_hat * test_data.z(i); + y_rmse_sum += (y_hat - test_data.latent_outcome(i)) * (y_hat - test_data.latent_outcome(i)); + } + + std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" + << " mu RMSE (test): " << std::sqrt(mu_rmse_sum / n_test) << "\n" + << " tau RMSE (test): " << std::sqrt(tau_rmse_sum / n_test) 
<< "\n" + << " latent outcome RMSE (test): " << std::sqrt(y_rmse_sum / n_test) << "\n" + << " sigma (last sample): " << std::sqrt(global_variance) << "\n" + << " sigma (truth): 1\n"; + }; + + // Dispatch BCF sampler + run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, + dataset, residual, rng, test_dataset, post_iter, report); } // ---- Main ----------------------------------------------------------- @@ -321,6 +449,9 @@ int main(int argc, char** argv) { case 0: run_scenario_0(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); break; + case 1: + run_scenario_1(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); + break; default: std::cerr << "Unknown scenario " << scenario << ". Available scenarios: 0 (BCF: constant mu + univariate tau)\n"; From 0cca3842048739281013fb4b8cbfb5a4c026552c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 18:06:23 -0400 Subject: [PATCH 17/64] Added linear regression and tests --- include/stochtree/linear_regression.h | 164 ++++++++++++++++++ test/cpp/test_linear_regression.cpp | 234 ++++++++++++++++++++++++++ 2 files changed, 398 insertions(+) create mode 100644 include/stochtree/linear_regression.h create mode 100644 test/cpp/test_linear_regression.cpp diff --git a/include/stochtree/linear_regression.h b/include/stochtree/linear_regression.h new file mode 100644 index 00000000..c9343f73 --- /dev/null +++ b/include/stochtree/linear_regression.h @@ -0,0 +1,164 @@ +/*! + * Copyright (c) 2026 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef STOCHTREE_REGRESSION_H_ +#define STOCHTREE_REGRESSION_H_ + +#include +#include +#include + +#include +#include "Eigen/src/Core/Matrix.h" + +namespace StochTree { + +/*! 
+ * \brief Sample a regression coefficient from the posterior distribution of a univariate Gaussian regression model with known error variance and known prior variance. + * + * \param y Pointer to outcome array of length n + * \param x Pointer to covariate array of length n + * \param error_variance Known error variance (sigma^2) + * \param prior_variance Known prior variance (tau^2) + * \param n Number of observations + * \param gen Random number generator + * \return double + */ +inline double sample_univariate_gaussian_regression_coefficient(double* y, double* x, double error_variance, double prior_variance, int n, std::mt19937& gen) { + double sum_xx = 0.0; + double sum_yx = 0.0; + for (int i = 0; i < n; i++) { + sum_xx += x[i] * x[i]; + sum_yx += y[i] * x[i]; + } + double post_var = (prior_variance * error_variance) / (sum_xx * prior_variance + error_variance); + double post_mean = post_var * (sum_yx / error_variance); + return sample_standard_normal(post_mean, std::sqrt(post_var), gen); +} + +/*! + * \brief Sample regression coefficients from the posterior distribution of a bivariate Gaussian regression model with known error variance and known prior variance. 
+ * + * \param output Pointer to sampled coefficient array of length 2, where the sampled coefficients will be stored + * \param y Pointer to outcome array of length n + * \param x1 Pointer to first covariate array of length n + * \param x2 Pointer to second covariate array of length n + * \param error_variance Known error variance (sigma^2) + * \param prior_variance_11 First diagonal element of the known prior covariance matrix (tau^2 for the first coefficient) + * \param prior_variance_12 Off-diagonal element of the known prior covariance matrix (covariance between the two coefficients) -- note that this is the same as prior_variance_21 so only one off-diagonal element is needed + * \param prior_variance_22 Second diagonal element of the known prior covariance matrix (tau^2 for the second coefficient) + * \param n Number of observations + * \param gen Random number generator + */ +inline void sample_general_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_12, double prior_variance_22, int n, std::mt19937& gen) { + double det_prior_var = prior_variance_11 * prior_variance_22 - prior_variance_12 * prior_variance_12; + double inv_prior_var_11 = prior_variance_22 / det_prior_var; + double inv_prior_var_12 = -prior_variance_12 / det_prior_var; + double inv_prior_var_22 = prior_variance_11 / det_prior_var; + double sum_x1x1 = 0.0; + double sum_x1x2 = 0.0; + double sum_x2x2 = 0.0; + double sum_yx1 = 0.0; + double sum_yx2 = 0.0; + for (int i = 0; i < n; i++) { + sum_x1x1 += x1[i] * x1[i]; + sum_x1x2 += x1[i] * x2[i]; + sum_x2x2 += x2[i] * x2[i]; + sum_yx1 += y[i] * x1[i]; + sum_yx2 += y[i] * x2[i]; + } + double post_var_pre_inv_11 = inv_prior_var_11 + sum_x1x1 / error_variance; + double post_var_pre_inv_12 = inv_prior_var_12 + sum_x1x2 / error_variance; + double post_var_pre_inv_22 = inv_prior_var_22 + sum_x2x2 / error_variance; + double 
det_post_var_pre_inv = post_var_pre_inv_11 * post_var_pre_inv_22 - post_var_pre_inv_12 * post_var_pre_inv_12; + double post_var_11 = post_var_pre_inv_22 / det_post_var_pre_inv; + double post_var_12 = -post_var_pre_inv_12 / det_post_var_pre_inv; + double post_var_22 = post_var_pre_inv_11 / det_post_var_pre_inv; + double post_mean_1 = post_var_11 * (sum_yx1 / error_variance) + post_var_12 * (sum_yx2 / error_variance); + double post_mean_2 = post_var_12 * (sum_yx1 / error_variance) + post_var_22 * (sum_yx2 / error_variance); + double chol_var_11 = std::sqrt(post_var_11); + double chol_var_12 = post_var_12 / chol_var_11; + double chol_var_22 = std::sqrt(post_var_22 - chol_var_12 * chol_var_12); + double z1 = sample_standard_normal(0.0, 1.0, gen); + double z2 = sample_standard_normal(0.0, 1.0, gen); + output[0] = post_mean_1 + chol_var_11 * z1; + output[1] = post_mean_2 + chol_var_12 * z1 + chol_var_22 * z2; +} + +/*! + * \brief Sample regression coefficients from the posterior distribution of a bivariate Gaussian regression model with known error variance and known diagonal prior variance. 
+ * + * \param output Pointer to sampled coefficient array of length 2, where the sampled coefficients will be stored + * \param y Pointer to outcome array of length n + * \param x1 Pointer to first covariate array of length n + * \param x2 Pointer to second covariate array of length n + * \param error_variance Known error variance (sigma^2) + * \param prior_variance_11 First diagonal element of the known prior covariance matrix (tau^2 for the first coefficient) + * \param prior_variance_22 Second diagonal element of the known prior covariance matrix (tau^2 for the second coefficient) + * \param n Number of observations + * \param gen Random number generator + */ +inline void sample_diagonal_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_22, int n, std::mt19937& gen) { + double inv_prior_var_11 = 1.0 / prior_variance_11; + double inv_prior_var_22 = 1.0 / prior_variance_22; + double sum_x1x1 = 0.0; + double sum_x1x2 = 0.0; + double sum_x2x2 = 0.0; + double sum_yx1 = 0.0; + double sum_yx2 = 0.0; + for (int i = 0; i < n; i++) { + sum_x1x1 += x1[i] * x1[i]; + sum_x1x2 += x1[i] * x2[i]; + sum_x2x2 += x2[i] * x2[i]; + sum_yx1 += y[i] * x1[i]; + sum_yx2 += y[i] * x2[i]; + } + double post_var_pre_inv_11 = inv_prior_var_11 + sum_x1x1 / error_variance; + double post_var_pre_inv_12 = sum_x1x2 / error_variance; + double post_var_pre_inv_22 = inv_prior_var_22 + sum_x2x2 / error_variance; + double det_post_var_pre_inv = post_var_pre_inv_11 * post_var_pre_inv_22 - post_var_pre_inv_12 * post_var_pre_inv_12; + double post_var_11 = post_var_pre_inv_22 / det_post_var_pre_inv; + double post_var_12 = -post_var_pre_inv_12 / det_post_var_pre_inv; + double post_var_22 = post_var_pre_inv_11 / det_post_var_pre_inv; + double post_mean_1 = post_var_11 * (sum_yx1 / error_variance) + post_var_12 * (sum_yx2 / error_variance); + double post_mean_2 = post_var_12 * (sum_yx1 / 
error_variance) + post_var_22 * (sum_yx2 / error_variance); + double chol_var_11 = std::sqrt(post_var_11); + double chol_var_12 = post_var_12 / chol_var_11; + double chol_var_22 = std::sqrt(post_var_22 - chol_var_12 * chol_var_12); + double z1 = sample_standard_normal(0.0, 1.0, gen); + double z2 = sample_standard_normal(0.0, 1.0, gen); + output[0] = post_mean_1 + chol_var_11 * z1; + output[1] = post_mean_2 + chol_var_12 * z1 + chol_var_22 * z2; +} + +/*! + * \brief Sample regression coefficients from the posterior distribution of a bivariate Gaussian regression model with known error variance and known diagonal prior variance. + * + * \param y Eigen::VectorXd of outcomes of length n + * \param X Eigen::MatrixXd of covariates with n rows and p columns + * \param error_variance Known error variance (sigma^2) + * \param prior_variance Eigen::MatrixXd of known prior covariance matrix (tau^2 for the coefficients) of dimension p x p + * \param n Number of observations + * \param gen Random number generator + */ +Eigen::VectorXd sample_general_gaussian_regression_coefficients(Eigen::VectorXd& y, Eigen::MatrixXd& X, double error_variance, Eigen::MatrixXd& prior_variance, int n, std::mt19937& gen) { + int p = X.cols(); + Eigen::MatrixXd inv_prior_var = prior_variance.inverse(); + Eigen::MatrixXd XtX = X.transpose() * X; + Eigen::VectorXd Xty = X.transpose() * y; + Eigen::MatrixXd post_var_pre_inv = inv_prior_var + XtX / error_variance; + Eigen::MatrixXd post_var = post_var_pre_inv.inverse(); + Eigen::VectorXd post_mean = post_var * (Xty / error_variance); + Eigen::LLT chol(post_var); + Eigen::MatrixXd L = chol.matrixL(); + Eigen::VectorXd z(p); + for (int i = 0; i < p; i++) { + z(i) = sample_standard_normal(0.0, 1.0, gen); + } + return post_mean + L * z; +} + +} // namespace StochTree + +#endif // STOCHTREE_REGRESSION_H_ diff --git a/test/cpp/test_linear_regression.cpp b/test/cpp/test_linear_regression.cpp new file mode 100644 index 00000000..44d82c50 --- /dev/null +++ 
b/test/cpp/test_linear_regression.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "Eigen/src/Core/Matrix.h" + +TEST(LinearRegression, UnivariateDegeneratePosteriorMeanCorrectness) { + // Test that the posterior mean of the regression coefficient is correct in a degenerate case where the outcome has no variance + // and the prior variance is nearly infinite (i.e. the posterior mean should equal the OLS estimate). + + // Generate data + std::mt19937 gen(1234); + int n = 100; + std::vector x(n, 0.0); + std::vector y(n, 0.0); + for (int i = 0; i < n; i++) { + x[i] = StochTree::standard_uniform_draw_53bit(gen); + y[i] = 2.0 * x[i]; + } + double sigma2 = 1e-6; // near-zero outcome variance + double tau2 = 1e6; // near-infinite prior variance + + // Sample from the regression model + int num_samples = 1000; + std::vector beta_samples(num_samples); + for (int i = 0; i < num_samples; i++) { + beta_samples[i] = StochTree::sample_univariate_gaussian_regression_coefficient(y.data(), x.data(), sigma2, tau2, n, gen); + } + + // Check posterior mean is close to true value (which should also be the OLS estimate without noise) + double beta_mean = std::accumulate(beta_samples.begin(), beta_samples.end(), 0.0) / num_samples; + double ols_estimate = 2.0; + EXPECT_NEAR(beta_mean, ols_estimate, 1e-2); +} + +TEST(LinearRegression, UnivariatePosteriorMeanCorrectness) { + // Test that the sampled regression coefficients average out close to the expected posterior mean with enough samples + + // Generate data + std::mt19937 gen(1234); + int n = 100; + std::vector x(n, 0.0); + std::vector y(n, 0.0); + for (int i = 0; i < n; i++) { + x[i] = StochTree::standard_uniform_draw_53bit(gen); + y[i] = 2.0 * x[i] + StochTree::sample_standard_normal(0.0, 0.1, gen); + } + double sigma2 = 1; + double tau2 = 1; + + // Compute the "true" posterior mean analytically for comparison + double sum_xx = 0.0; + double sum_yx = 0.0; + for (int i = 0; 
i < n; i++) { + sum_xx += x[i] * x[i]; + sum_yx += y[i] * x[i]; + } + double post_mean = (sum_yx * sigma2) / (sum_xx * tau2 + sigma2); + + // Draw many samples from the posterior and compute their average + int num_samples = 10000; + std::vector beta_samples(num_samples); + for (int i = 0; i < num_samples; i++) { + beta_samples[i] = StochTree::sample_univariate_gaussian_regression_coefficient(y.data(), x.data(), sigma2, tau2, n, gen); + } + double beta_mean = std::accumulate(beta_samples.begin(), beta_samples.end(), 0.0) / num_samples; + EXPECT_NEAR(beta_mean, post_mean, 1e-2); +} + +TEST(LinearRegression, BivariatePosteriorMeanCorrectness) { + // Test that the sampled regression coefficients average out close to the expected posterior mean with enough samples + + // Generate data + std::mt19937 gen(1234); + int n = 100; + std::vector x1(n, 0.0); + std::vector x2(n, 0.0); + std::vector y(n, 0.0); + for (int i = 0; i < n; i++) { + x1[i] = StochTree::standard_uniform_draw_53bit(gen); + x2[i] = StochTree::standard_uniform_draw_53bit(gen); + y[i] = 2.0 * x1[i] + 3.0 * x2[i] + StochTree::sample_standard_normal(0.0, 0.1, gen); + } + double sigma2 = 1; + double prior_variance_11 = 1; + double prior_variance_12 = 0.5; + double prior_variance_22 = 1; + + // Compute the "true" posterior mean analytically for comparison + double det_prior_var = prior_variance_11 * prior_variance_22 - prior_variance_12 * prior_variance_12; + double inv_prior_var_11 = prior_variance_22 / det_prior_var; + double inv_prior_var_12 = -prior_variance_12 / det_prior_var; + double inv_prior_var_22 = prior_variance_11 / det_prior_var; + double sum_x1x1 = 0.0; + double sum_x2x2 = 0.0; + double sum_x1x2 = 0.0; + double sum_yx1 = 0.0; + double sum_yx2 = 0.0; + for (int i = 0; i < n; i++) { + sum_x1x1 += x1[i] * x1[i]; + sum_x2x2 += x2[i] * x2[i]; + sum_x1x2 += x1[i] * x2[i]; + sum_yx1 += y[i] * x1[i]; + sum_yx2 += y[i] * x2[i]; + } + double post_var_pre_inv_11 = inv_prior_var_11 + sum_x1x1 / sigma2; + 
double post_var_pre_inv_12 = inv_prior_var_12 + sum_x1x2 / sigma2; + double post_var_pre_inv_22 = inv_prior_var_22 + sum_x2x2 / sigma2; + double det_post_var_pre_inv = post_var_pre_inv_11 * post_var_pre_inv_22 - post_var_pre_inv_12 * post_var_pre_inv_12; + double post_var_11 = post_var_pre_inv_22 / det_post_var_pre_inv; + double post_var_12 = -post_var_pre_inv_12 / det_post_var_pre_inv; + double post_var_22 = post_var_pre_inv_11 / det_post_var_pre_inv; + double post_mean_1 = post_var_11 * (sum_yx1 / sigma2) + post_var_12 * (sum_yx2 / sigma2); + double post_mean_2 = post_var_12 * (sum_yx1 / sigma2) + post_var_22 * (sum_yx2 / sigma2); + + // Draw many samples from the posterior and compute their average + int num_samples = 10000; + double beta_mean_1_sum = 0.0; + double beta_mean_2_sum = 0.0; + std::vector beta_samples(num_samples * 2); + for (int i = 0; i < num_samples; i++) { + StochTree::sample_general_bivariate_gaussian_regression_coefficients(beta_samples.data() + 2 * i, y.data(), x1.data(), x2.data(), sigma2, prior_variance_11, prior_variance_12, prior_variance_22, n, gen); + beta_mean_1_sum += beta_samples[2 * i]; + beta_mean_2_sum += beta_samples[2 * i + 1]; + } + double beta_mean_1 = beta_mean_1_sum / num_samples; + double beta_mean_2 = beta_mean_2_sum / num_samples; + EXPECT_NEAR(beta_mean_1, post_mean_1, 1e-2); + EXPECT_NEAR(beta_mean_2, post_mean_2, 1e-2); +} + +TEST(LinearRegression, BivariateMatchWhenDiagonalPrior) { + // Test that the sampled regression coefficients for the general bivariate and specialized diagonal bivariate samplers are close to each other with enough samples when the covariance is diagonal + + // Generate data + std::mt19937 gen(1234); + int n = 100; + std::vector x1(n, 0.0); + std::vector x2(n, 0.0); + std::vector y(n, 0.0); + for (int i = 0; i < n; i++) { + x1[i] = StochTree::standard_uniform_draw_53bit(gen); + x2[i] = StochTree::standard_uniform_draw_53bit(gen); + y[i] = 2.0 * x1[i] + 3.0 * x2[i] + 
StochTree::sample_standard_normal(0.0, 0.1, gen); + } + double sigma2 = 1; + double prior_variance_11 = 1; + double prior_variance_12 = 0; + double prior_variance_22 = 1; + + // Draw many samples from the posterior and compute their average + int num_samples = 10000; + double beta_mean_1_sum_general = 0.0; + double beta_mean_2_sum_general = 0.0; + double beta_mean_1_sum_diagonal = 0.0; + double beta_mean_2_sum_diagonal = 0.0; + std::vector beta_samples_general(num_samples * 2); + std::vector beta_samples_diagonal(num_samples * 2); + for (int i = 0; i < num_samples; i++) { + StochTree::sample_general_bivariate_gaussian_regression_coefficients(beta_samples_general.data() + 2 * i, y.data(), x1.data(), x2.data(), sigma2, prior_variance_11, prior_variance_12, prior_variance_22, n, gen); + StochTree::sample_diagonal_bivariate_gaussian_regression_coefficients(beta_samples_diagonal.data() + 2 * i, y.data(), x1.data(), x2.data(), sigma2, prior_variance_11, prior_variance_22, n, gen); + beta_mean_1_sum_general += beta_samples_general[2 * i]; + beta_mean_2_sum_general += beta_samples_general[2 * i + 1]; + beta_mean_1_sum_diagonal += beta_samples_diagonal[2 * i]; + beta_mean_2_sum_diagonal += beta_samples_diagonal[2 * i + 1]; + } + double beta_mean_1_general = beta_mean_1_sum_general / num_samples; + double beta_mean_2_general = beta_mean_2_sum_general / num_samples; + double beta_mean_1_diagonal = beta_mean_1_sum_diagonal / num_samples; + double beta_mean_2_diagonal = beta_mean_2_sum_diagonal / num_samples; + EXPECT_NEAR(beta_mean_1_general, beta_mean_1_diagonal, 1e-2); + EXPECT_NEAR(beta_mean_2_general, beta_mean_2_diagonal, 1e-2); +} + +TEST(LinearRegression, MultivariateBivariateMatch) { + // Test that the sampled regression coefficients for the general bivariate and multivariate samplers are close to each other with enough samples when covariates are bivariate + + // Generate data + std::mt19937 gen(1234); + int n = 100; + std::vector x1(n, 0.0); + std::vector x2(n, 0.0); 
+ std::vector y(n, 0.0); + Eigen::VectorXd y_eigen(n); + Eigen::MatrixXd X_eigen(n, 2); + double x1_elem, x2_elem, y_elem; + for (int i = 0; i < n; i++) { + x1_elem = StochTree::standard_uniform_draw_53bit(gen); + x2_elem = StochTree::standard_uniform_draw_53bit(gen); + y_elem = 2.0 * x1_elem + 3.0 * x2_elem + StochTree::sample_standard_normal(0.0, 0.1, gen); + x1[i] = x1_elem; + x2[i] = x2_elem; + y[i] = y_elem; + y_eigen(i) = y_elem; + X_eigen(i, 0) = x1_elem; + X_eigen(i, 1) = x2_elem; + } + double sigma2 = 1; + double prior_variance_11 = 1; + double prior_variance_12 = 0; + double prior_variance_22 = 1; + Eigen::MatrixXd prior_variance(2, 2); + prior_variance(0, 0) = prior_variance_11; + prior_variance(0, 1) = prior_variance_12; + prior_variance(1, 0) = prior_variance_12; + prior_variance(1, 1) = prior_variance_22; + + // Draw many samples from the posterior and compute their average + int num_samples = 10000; + double beta_mean_1_sum_bivariate = 0.0; + double beta_mean_2_sum_bivariate = 0.0; + double beta_mean_1_sum_multivariate = 0.0; + double beta_mean_2_sum_multivariate = 0.0; + std::vector beta_samples_bivariate(num_samples * 2); + Eigen::VectorXd beta(2); + for (int i = 0; i < num_samples; i++) { + StochTree::sample_general_bivariate_gaussian_regression_coefficients(beta_samples_bivariate.data() + 2 * i, y.data(), x1.data(), x2.data(), sigma2, prior_variance_11, prior_variance_12, prior_variance_22, n, gen); + beta = StochTree::sample_general_gaussian_regression_coefficients(y_eigen, X_eigen, sigma2, prior_variance, n, gen); + beta_mean_1_sum_bivariate += beta_samples_bivariate[2 * i]; + beta_mean_2_sum_bivariate += beta_samples_bivariate[2 * i + 1]; + beta_mean_1_sum_multivariate += beta(0); + beta_mean_2_sum_multivariate += beta(1); + } + double beta_mean_1_bivariate = beta_mean_1_sum_bivariate / num_samples; + double beta_mean_2_bivariate = beta_mean_2_sum_bivariate / num_samples; + double beta_mean_1_multivariate = beta_mean_1_sum_multivariate / 
num_samples; + double beta_mean_2_multivariate = beta_mean_2_sum_multivariate / num_samples; + EXPECT_NEAR(beta_mean_1_bivariate, beta_mean_1_multivariate, 1e-2); + EXPECT_NEAR(beta_mean_2_bivariate, beta_mean_2_multivariate, 1e-2); +} From c003b71af23d0d1fd661f079c3883b80b8df2dc9 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 18:08:21 -0400 Subject: [PATCH 18/64] Formatted C++ test code --- test/cpp/test_category_tracker.cpp | 9 +- test/cpp/test_cutpoints.cpp | 1 - test/cpp/test_data.cpp | 11 +- test/cpp/test_forest.cpp | 47 +- test/cpp/test_json.cpp | 34 +- test/cpp/test_linear_regression.cpp | 1 - test/cpp/test_predict.cpp | 21 +- test/cpp/test_random_effects.cpp | 54 ++- test/cpp/test_sorted_partition_tracker.cpp | 25 +- test/cpp/test_tree.cpp | 26 +- test/cpp/test_unsorted_partition_tracker.cpp | 14 +- test/cpp/testutils.cpp | 444 +++++++++---------- test/cpp/testutils.h | 4 +- 13 files changed, 341 insertions(+), 350 deletions(-) diff --git a/test/cpp/test_category_tracker.cpp b/test/cpp/test_category_tracker.cpp index 05f28b2a..4ef249be 100644 --- a/test/cpp/test_category_tracker.cpp +++ b/test/cpp/test_category_tracker.cpp @@ -5,22 +5,19 @@ #include #include #include -#include -#include #include TEST(CategorySampleTracker, BasicOperations) { // Create a vector of categorical data - std::vector category_data { - 3, 4, 3, 2, 2, 4, 3, 3, 3, 4, 3, 4 - }; + std::vector category_data{ + 3, 4, 3, 2, 2, 4, 3, 3, 3, 4, 3, 4}; // Create a CategorySamplerTracker StochTree::CategorySampleTracker category_tracker = StochTree::CategorySampleTracker(category_data); // Extract the label map std::map label_map = category_tracker.GetLabelMap(); - std::map expected_label_map {{2, 0}, {3, 1}, {4, 2}}; + std::map expected_label_map{{2, 0}, {3, 1}, {4, 2}}; // Check that the map was constructed as expected ASSERT_EQ(label_map[2], 0); diff --git a/test/cpp/test_cutpoints.cpp b/test/cpp/test_cutpoints.cpp index 5e71a498..42fbe388 100644 --- 
a/test/cpp/test_cutpoints.cpp +++ b/test/cpp/test_cutpoints.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/test/cpp/test_data.cpp b/test/cpp/test_data.cpp index 62215cff..18a23a08 100644 --- a/test/cpp/test_data.cpp +++ b/test/cpp/test_data.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include TEST(Data, ReadFromSmallDatasetRowMajor) { // Load test data @@ -25,14 +23,14 @@ TEST(Data, ReadFromSmallDatasetRowMajor) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Compute average value for each feature, compared to their known values std::vector total; std::vector average; total.resize(p, 0.); average.resize(p, 0.); for (int j = 0; j < p; j++) { - for (data_size_t i = 0; i < n; i++) { + for (data_size_t i = 0; i < n; i++) { total[j] += dataset.CovariateValue(i, j); } } @@ -60,14 +58,14 @@ TEST(Data, ReadFromMediumDatasetRowMajor) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Compute average value for each feature, compared to their known values std::vector total; std::vector average; total.resize(p, 0.); average.resize(p, 0.); for (int j = 0; j < p; j++) { - for (data_size_t i = 0; i < n; i++) { + for (data_size_t i = 0; i < n; i++) { total[j] += dataset.CovariateValue(i, j); } } @@ -80,4 +78,3 @@ TEST(Data, ReadFromMediumDatasetRowMajor) { EXPECT_NEAR(0.4863596, average[3], 0.0001); EXPECT_NEAR(0.4413101, average[4], 0.0001); } - diff --git a/test/cpp/test_forest.cpp 
b/test/cpp/test_forest.cpp index d870997c..cabc3289 100644 --- a/test/cpp/test_forest.cpp +++ b/test/cpp/test_forest.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include TEST(Forest, UnivariateForestConstruction) { int num_trees = 2; @@ -45,13 +43,13 @@ TEST(Forest, UnivariateForestMerge) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Create a small ensemble int output_dim = 1; int num_trees = 2; bool is_leaf_constant = true; StochTree::TreeEnsemble ensemble1(num_trees, output_dim, is_leaf_constant); - + // Create another small ensemble StochTree::TreeEnsemble ensemble2(num_trees, output_dim, is_leaf_constant); @@ -63,8 +61,8 @@ TEST(Forest, UnivariateForestMerge) { tree->ExpandNode(0, 1, tree_split, -2.5, 2.5); // Run predict on the supplied covariates and check the result for the first forest - std::vector result(n*output_dim); - std::vector expected_pred = {7.5,2.5,-7.5,7.5,7.5,7.5,2.5,7.5,-7.5,-2.5}; + std::vector result(n * output_dim); + std::vector expected_pred = {7.5, 2.5, -7.5, 7.5, 7.5, 7.5, 2.5, 7.5, -7.5, -2.5}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -79,8 +77,8 @@ TEST(Forest, UnivariateForestMerge) { tree->ExpandNode(0, 2, tree_split, -0.5, 0.5); // Run predict on the supplied covariates and check the result for the second forest - result = std::vector(n*output_dim); - expected_pred = std::vector{1.5,-1.5,-0.5,-0.5,1.5,1.5,-1.5,1.5,-0.5,-1.5}; + result = std::vector(n * output_dim); + expected_pred = std::vector{1.5, -1.5, -0.5, -0.5, 1.5, 1.5, -1.5, 1.5, -0.5, -1.5}; ensemble2.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { 
ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -88,8 +86,8 @@ TEST(Forest, UnivariateForestMerge) { // Merge the second forest into the first ensemble1.MergeForest(ensemble2); - result = std::vector(n*output_dim); - expected_pred = std::vector{9,1,-8,7,9,9,1,9,-8,-4}; + result = std::vector(n * output_dim); + expected_pred = std::vector{9, 1, -8, 7, 9, 9, 1, 9, -8, -4}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -108,13 +106,13 @@ TEST(Forest, UnivariateForestAdd) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Create a small ensemble int output_dim = 1; int num_trees = 2; bool is_leaf_constant = true; StochTree::TreeEnsemble ensemble1(num_trees, output_dim, is_leaf_constant); - + // Create another small ensemble StochTree::TreeEnsemble ensemble2(num_trees, output_dim, is_leaf_constant); @@ -126,8 +124,8 @@ TEST(Forest, UnivariateForestAdd) { tree->ExpandNode(0, 1, tree_split, -2.5, 2.5); // Run predict on the supplied covariates and check the result for the first forest - std::vector result(n*output_dim); - std::vector expected_pred = {7.5,2.5,-7.5,7.5,7.5,7.5,2.5,7.5,-7.5,-2.5}; + std::vector result(n * output_dim); + std::vector expected_pred = {7.5, 2.5, -7.5, 7.5, 7.5, 7.5, 2.5, 7.5, -7.5, -2.5}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -137,8 +135,8 @@ TEST(Forest, UnivariateForestAdd) { ensemble1.AddValueToLeaves(1.0); // Run predict on the supplied covariates and check the result for the first forest - result = std::vector(n*output_dim); - expected_pred = 
std::vector{9.5,4.5,-5.5,9.5,9.5,9.5,4.5,9.5,-5.5,-0.5}; + result = std::vector(n * output_dim); + expected_pred = std::vector{9.5, 4.5, -5.5, 9.5, 9.5, 9.5, 4.5, 9.5, -5.5, -0.5}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -153,8 +151,8 @@ TEST(Forest, UnivariateForestAdd) { tree->ExpandNode(0, 2, tree_split, -0.5, 0.5); // Run predict on the supplied covariates and check the result for the second forest - result = std::vector(n*output_dim); - expected_pred = std::vector{1.5,-1.5,-0.5,-0.5,1.5,1.5,-1.5,1.5,-0.5,-1.5}; + result = std::vector(n * output_dim); + expected_pred = std::vector{1.5, -1.5, -0.5, -0.5, 1.5, 1.5, -1.5, 1.5, -0.5, -1.5}; ensemble2.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -164,8 +162,8 @@ TEST(Forest, UnivariateForestAdd) { ensemble2.AddValueToLeaves(-1.0); // Run predict on the supplied covariates and check the result for the first forest - result = std::vector(n*output_dim); - expected_pred = std::vector{-0.5,-3.5,-2.5,-2.5,-0.5,-0.5,-3.5,-0.5,-2.5,-3.5}; + result = std::vector(n * output_dim); + expected_pred = std::vector{-0.5, -3.5, -2.5, -2.5, -0.5, -0.5, -3.5, -0.5, -2.5, -3.5}; ensemble2.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -173,8 +171,8 @@ TEST(Forest, UnivariateForestAdd) { // Merge the second forest into the first ensemble1.MergeForest(ensemble2); - result = std::vector(n*output_dim); - expected_pred = std::vector{9,1,-8,7,9,9,1,9,-8,-4}; + result = std::vector(n * output_dim); + expected_pred = std::vector{9, 1, -8, 7, 9, 9, 1, 9, -8, -4}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -182,11 +180,10 @@ TEST(Forest, UnivariateForestAdd) { // Merge the 
second forest into the first ensemble1.MultiplyLeavesByValue(2.0); - result = std::vector(n*output_dim); - expected_pred = std::vector{18,2,-16,14,18,18,2,18,-16,-8}; + result = std::vector(n * output_dim); + expected_pred = std::vector{18, 2, -16, 14, 18, 18, 2, 18, -16, -8}; ensemble1.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); } } - diff --git a/test/cpp/test_json.cpp b/test/cpp/test_json.cpp index 77a28faf..ef2bdfbf 100644 --- a/test/cpp/test_json.cpp +++ b/test/cpp/test_json.cpp @@ -7,15 +7,13 @@ #include #include #include -#include -#include TEST(Json, TreeUnivariateLeaf) { // Initialize tree StochTree::Tree tree; StochTree::TreeSplit split; tree.Init(1); - + // Perform three splits split = StochTree::TreeSplit(0.5); tree.ExpandNode(0, 0, split, 0., 0.); @@ -23,10 +21,10 @@ TEST(Json, TreeUnivariateLeaf) { tree.ExpandNode(1, 1, split, 0., 0.); split = StochTree::TreeSplit(0.6); tree.ExpandNode(3, 2, split, 0., 0.); - + // Prune node 3 to a leaf tree.CollapseToLeaf(3, 0.); - + // Write to json nlohmann::json tree_json = tree.to_json(); @@ -43,20 +41,20 @@ TEST(Json, TreeUnivariateLeafCategoricalSplit) { StochTree::Tree tree; StochTree::TreeSplit split; tree.Init(1); - + // Perform three splits - std::vector split_categories_1{1,3,5,7}; + std::vector split_categories_1{1, 3, 5, 7}; split = StochTree::TreeSplit(split_categories_1); tree.ExpandNode(0, 0, split, 0., 0.); - std::vector split_categories_2{2,3,5}; + std::vector split_categories_2{2, 3, 5}; split = StochTree::TreeSplit(split_categories_2); tree.ExpandNode(1, 1, split, 0., 0.); split = StochTree::TreeSplit(0.6); tree.ExpandNode(3, 2, split, 0., 0.); - + // Prune node 3 to a leaf tree.CollapseToLeaf(3, 0.); - + // Write to json nlohmann::json tree_json = tree.to_json(); @@ -80,7 +78,7 @@ TEST(Json, TreeMultivariateLeaf) { std::vector leaf_values5(tree_dim, 345235636.4); std::vector leaf_values6(tree_dim, 10023.1); 
tree.Init(tree_dim); - + // Perform three splits split = StochTree::TreeSplit(0.5); tree.ExpandNode(0, 0, split, leaf_values1, leaf_values2); @@ -88,10 +86,10 @@ TEST(Json, TreeMultivariateLeaf) { tree.ExpandNode(1, 1, split, leaf_values3, leaf_values4); split = StochTree::TreeSplit(0.6); tree.ExpandNode(1, 1, split, leaf_values5, leaf_values6); - + // Prune node 3 to a leaf tree.CollapseToLeaf(3, leaf_values3); - + // Write to json nlohmann::json tree_json = tree.to_json(); @@ -116,20 +114,20 @@ TEST(Json, TreeMultivariateLeafCategoricalSplit) { std::vector leaf_values5(tree_dim, 345235636.4); std::vector leaf_values6(tree_dim, 10023.1); tree.Init(tree_dim); - + // Perform three splits - std::vector split_categories_1{1,3,5,7}; + std::vector split_categories_1{1, 3, 5, 7}; split = StochTree::TreeSplit(split_categories_1); tree.ExpandNode(0, 0, split, leaf_values1, leaf_values2); - std::vector split_categories_2{2,3,5}; + std::vector split_categories_2{2, 3, 5}; split = StochTree::TreeSplit(split_categories_2); tree.ExpandNode(1, 1, split, leaf_values3, leaf_values4); split = StochTree::TreeSplit(0.6); tree.ExpandNode(1, 1, split, leaf_values5, leaf_values6); - + // Prune node 3 to a leaf tree.CollapseToLeaf(3, leaf_values3); - + // Write to json nlohmann::json tree_json = tree.to_json(); diff --git a/test/cpp/test_linear_regression.cpp b/test/cpp/test_linear_regression.cpp index 44d82c50..92ee8610 100644 --- a/test/cpp/test_linear_regression.cpp +++ b/test/cpp/test_linear_regression.cpp @@ -6,7 +6,6 @@ #include #include #include -#include "Eigen/src/Core/Matrix.h" TEST(LinearRegression, UnivariateDegeneratePosteriorMeanCorrectness) { // Test that the posterior mean of the regression coefficient is correct in a degenerate case where the outcome has no variance diff --git a/test/cpp/test_predict.cpp b/test/cpp/test_predict.cpp index 8f35adfa..ed72ae17 100644 --- a/test/cpp/test_predict.cpp +++ b/test/cpp/test_predict.cpp @@ -1,14 +1,10 @@ -/*! 
- * Test of the ensemble prediction method - */ +/*! Test of the ensemble prediction method */ #include #include #include #include #include #include -#include -#include /*! \brief Test forest prediction procedures for trees with constants in leaf nodes */ TEST(Ensemble, PredictConstant) { @@ -23,7 +19,7 @@ TEST(Ensemble, PredictConstant) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Create a small ensemble int output_dim = 1; int num_trees = 2; @@ -38,8 +34,8 @@ TEST(Ensemble, PredictConstant) { tree->ExpandNode(0, 1, tree_split, -2.5, 2.5); // Run predict on the supplied covariates and check the result - std::vector result(n*output_dim); - std::vector expected_pred = {7.5,2.5,-7.5,7.5,7.5,7.5,2.5,7.5,-7.5,-2.5}; + std::vector result(n * output_dim); + std::vector expected_pred = {7.5, 2.5, -7.5, 7.5, 7.5, 7.5, 2.5, 7.5, -7.5, -2.5}; ensemble.PredictInplace(dataset.GetCovariates(), result, 0); for (int i = 0; i < n; i++) { ASSERT_NEAR(expected_pred[i], result[i], 0.01); @@ -59,7 +55,7 @@ TEST(Ensemble, PredictUnivariateRegression) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Create a small ensemble int output_dim = 1; int num_trees = 2; @@ -71,10 +67,11 @@ TEST(Ensemble, PredictUnivariateRegression) { StochTree::TreeSplit tree_split = StochTree::TreeSplit(0.5); tree->ExpandNode(0, 0, tree_split, -5., 5.); tree = ensemble.GetTree(1); - tree->ExpandNode(0, 1, tree_split, -2.5, 2.5);; + tree->ExpandNode(0, 1, tree_split, -2.5, 2.5); + ; // 
Run predict on the supplied covariates and check the result - std::vector result(n*output_dim); + std::vector result(n * output_dim); std::vector expected_pred = {7.3351256, 0.8511415, -1.5396290, 5.7172741, 4.7433491, 4.5919388, 1.0123031, 2.4834167, -6.5187785, -1.4611208}; ensemble.PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), result, 0); for (int i = 0; i < n; i++) { @@ -94,7 +91,7 @@ TEST(Ensemble, PredictMultivariateRegression) { dataset.AddCovariates(test_dataset.covariates.data(), n, test_dataset.x_cols, test_dataset.row_major); dataset.AddBasis(test_dataset.omega.data(), test_dataset.n, test_dataset.omega_cols, test_dataset.row_major); StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), n); - + // Create a small ensemble int output_dim = 2; int num_trees = 2; diff --git a/test/cpp/test_random_effects.cpp b/test/cpp/test_random_effects.cpp index 2908ced5..f703d9e7 100644 --- a/test/cpp/test_random_effects.cpp +++ b/test/cpp/test_random_effects.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include TEST(RandomEffects, Setup) { // Load test data @@ -22,7 +20,7 @@ TEST(RandomEffects, Setup) { StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset(); dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major); dataset.AddGroupLabels(test_dataset.rfx_groups); - + // Construct tracker, model state, and container StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups); StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups); @@ -30,7 +28,7 @@ TEST(RandomEffects, Setup) { // Check the internal label map of the RandomEffectsTracker std::map label_map = tracker.GetLabelMap(); - std::map expected_label_map {{1, 0}, {2, 1}}; + std::map expected_label_map{{1, 0}, {2, 1}}; 
ASSERT_EQ(label_map, expected_label_map); } @@ -45,7 +43,7 @@ TEST(RandomEffects, Construction) { StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset(); dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major); dataset.AddGroupLabels(test_dataset.rfx_groups); - + // Construct tracker, model state, and container StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups); StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups); @@ -65,7 +63,7 @@ TEST(RandomEffects, Construction) { model.SetGroupParameter(xi0, 0); model.SetGroupParameter(xi1, 1); model.SetGroupParameterCovariance(sigma); - + // Push to the container container.AddSample(model); @@ -83,17 +81,17 @@ TEST(RandomEffects, Construction) { // Check data in the container std::vector alpha_retrieved = container.GetAlpha(); - std::vector alpha_expected {1.5, 2.0}; + std::vector alpha_expected{1.5, 2.0}; for (int i = 0; i < alpha_expected.size(); i++) { ASSERT_EQ(alpha_retrieved[i], alpha_expected[i]); } std::vector xi_retrieved = container.GetXi(); - std::vector xi_expected {2, 4, 1, 3}; + std::vector xi_expected{2, 4, 1, 3}; for (int i = 0; i < xi_expected.size(); i++) { ASSERT_EQ(xi_retrieved[i], xi_expected[i]); } std::vector beta_retrieved = container.GetBeta(); - std::vector beta_expected {3, 6, 2, 6}; + std::vector beta_expected{3, 6, 2, 6}; for (int i = 0; i < beta_expected.size(); i++) { ASSERT_EQ(beta_retrieved[i], beta_expected[i]); } @@ -111,7 +109,7 @@ TEST(RandomEffects, Computation) { StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset(); dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major); dataset.AddGroupLabels(test_dataset.rfx_groups); - + // Construct tracker, model state, and 
container StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups); StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups); @@ -134,7 +132,7 @@ TEST(RandomEffects, Computation) { model.SetGroupParameter(xi2, 2); model.SetGroupParameterCovariance(sigma); double sigma2 = 1.; - + // Compute the posterior mean for the group parameters Eigen::VectorXd xi0_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 0); Eigen::VectorXd xi1_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 1); @@ -167,7 +165,7 @@ TEST(RandomEffects, Predict) { StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset(); dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major); dataset.AddGroupLabels(test_dataset.rfx_groups); - + // Construct tracker, model state, and container StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups); StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups); @@ -187,7 +185,7 @@ TEST(RandomEffects, Predict) { model.SetGroupParameter(xi0, 0); model.SetGroupParameter(xi1, 1); model.SetGroupParameterCovariance(sigma); - + // Push to the container container.AddSample(model); @@ -205,13 +203,31 @@ TEST(RandomEffects, Predict) { // Predict from the container int num_samples = 2; - std::vector output(n*num_samples); + std::vector output(n * num_samples); container.Predict(dataset, label_mapper, output); // Check predictions - std::vector output_expected { - 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, - 2, 6, 2, 6, 2, 6, 2, 6, 2, 6, + std::vector output_expected{ + 3, + 6, + 3, + 6, + 3, + 6, + 3, + 6, + 3, + 6, + 2, + 6, + 2, + 6, + 2, + 6, + 2, + 6, + 2, + 6, }; for (int 
i = 0; i < output.size(); i++) { ASSERT_EQ(output[i], output_expected[i]); @@ -229,7 +245,7 @@ TEST(RandomEffects, Serialization) { StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset(); dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major); dataset.AddGroupLabels(test_dataset.rfx_groups); - + // Construct tracker, model state, and container StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups); StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups); @@ -249,7 +265,7 @@ TEST(RandomEffects, Serialization) { model.SetGroupParameter(xi0, 0); model.SetGroupParameter(xi1, 1); model.SetGroupParameterCovariance(sigma); - + // Push to the container container.AddSample(model); diff --git a/test/cpp/test_sorted_partition_tracker.cpp b/test/cpp/test_sorted_partition_tracker.cpp index 99dbbe5d..752cf33d 100644 --- a/test/cpp/test_sorted_partition_tracker.cpp +++ b/test/cpp/test_sorted_partition_tracker.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include TEST(SortedNodeSampleTracker, BasicOperations) { @@ -52,23 +51,23 @@ TEST(SortedNodeSampleTracker, BasicOperations) { ASSERT_EQ(sample_node_mapper.GetNodeId(7, 0), 2); ASSERT_EQ(sample_node_mapper.GetNodeId(8, 0), 1); ASSERT_EQ(sample_node_mapper.GetNodeId(9, 0), 1); - + // Check that node begin and node end haven't changed for root node, but that the indices have been sifted ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(0, 0), 0); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(0, 0), n); - std::vector expected_result{2,8,9,4,1,6,3,7,0,5}; + std::vector expected_result{2, 8, 9, 4, 1, 6, 3, 7, 0, 5}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(0, 0), expected_result); - + // Check node begin and node end for left node ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(1, 
0), 0); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(1, 0), 3); - expected_result = {2,8,9}; + expected_result = {2, 8, 9}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(1, 0), expected_result); - + // Check node begin and node end for right node ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(2, 0), 3); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(2, 0), n); - expected_result = {4,1,6,3,7,0,5}; + expected_result = {4, 1, 6, 3, 7, 0, 5}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(2, 0), expected_result); // Partition right node based on X[,1] <= 0.5 @@ -77,24 +76,24 @@ TEST(SortedNodeSampleTracker, BasicOperations) { // Check that node begin and node end haven't changed for old right node, but that the indices have been sifted ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(2, 0), 3); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(2, 0), n); - expected_result = {1,6,4,3,7,0,5}; + expected_result = {1, 6, 4, 3, 7, 0, 5}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(2, 0), expected_result); // Check same indices for feature 1 ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(2, 1), 3); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(2, 1), n); - expected_result = {6,1,3,0,7,4,5}; + expected_result = {6, 1, 3, 0, 7, 4, 5}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(2, 1), expected_result); - + // Check node begin and node end for new left node ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(3, 1), 3); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(3, 1), 5); - expected_result = {6,1}; + expected_result = {6, 1}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(3, 1), expected_result); - + // Check node begin and node end for new right node ASSERT_EQ(sorted_node_sampler_tracker.NodeBegin(4, 1), 5); ASSERT_EQ(sorted_node_sampler_tracker.NodeEnd(4, 1), n); - expected_result = {3,0,7,4,5}; + expected_result = {3, 0, 7, 4, 5}; ASSERT_EQ(sorted_node_sampler_tracker.NodeIndices(4, 1), expected_result); } diff --git a/test/cpp/test_tree.cpp 
b/test/cpp/test_tree.cpp index c30302c2..7deecf48 100644 --- a/test/cpp/test_tree.cpp +++ b/test/cpp/test_tree.cpp @@ -7,8 +7,6 @@ #include #include #include -#include -#include TEST(Tree, UnivariateTreeConstruction) { StochTree::Tree tree; @@ -33,10 +31,10 @@ TEST(Tree, UnivariateTreeCopyConstruction) { StochTree::Tree tree_2; StochTree::TreeSplit split; tree_1.Init(1); - + // Check max depth ASSERT_EQ(tree_1.MaxLeafDepth(), 0); - + // Perform two splits split = StochTree::TreeSplit(0.5); tree_1.ExpandNode(0, 0, split, 0., 0.); @@ -46,7 +44,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { ASSERT_EQ(tree_1.MaxLeafDepth(), 2); ASSERT_EQ(tree_1.NumValidNodes(), 5); ASSERT_EQ(tree_1.NumLeafParents(), 1); - + // Check leaves std::vector leaves = tree_1.GetLeaves(); for (int i = 0; i < leaves.size(); i++) { @@ -57,7 +55,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { for (int i = 0; i < leaf_parents.size(); i++) { ASSERT_TRUE(tree_1.IsLeafParent(leaf_parents[i])); } - + // Perform another split split = StochTree::TreeSplit(0.6); tree_1.ExpandNode(3, 2, split, 0., 0.); @@ -65,7 +63,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { ASSERT_EQ(tree_1.NumValidNodes(), 7); ASSERT_EQ(tree_1.NumLeaves(), 4); ASSERT_EQ(tree_1.NumLeafParents(), 1); - + // Check leaves leaves = tree_1.GetLeaves(); for (int i = 0; i < leaves.size(); i++) { @@ -83,7 +81,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { ASSERT_EQ(tree_1.NumValidNodes(), 5); ASSERT_EQ(tree_1.NumLeaves(), 3); ASSERT_EQ(tree_1.NumLeafParents(), 1); - + // Check leaves leaves = tree_1.GetLeaves(); for (int i = 0; i < leaves.size(); i++) { @@ -106,12 +104,12 @@ TEST(Tree, UnivariateTreeCategoricalSplitConstruction) { StochTree::Tree tree; tree.Init(1); ASSERT_EQ(tree.LeafValue(0), 0.); - tree.ExpandNode(0, 0, std::vector{1,4,6}, 0., 0.); + tree.ExpandNode(0, 0, std::vector{1, 4, 6}, 0., 0.); ASSERT_EQ(tree.NumNodes(), 3); ASSERT_EQ(tree.NodeType(0), StochTree::TreeNodeType::kCategoricalSplitNode); 
tree.CollapseToLeaf(0, 0.); ASSERT_EQ(tree.NumValidNodes(), 1); - tree.ExpandNode(0, 0, std::vector{2,3,5}, 0., 0.); + tree.ExpandNode(0, 0, std::vector{2, 3, 5}, 0., 0.); ASSERT_EQ(tree.NodeType(0), StochTree::TreeNodeType::kCategoricalSplitNode); ASSERT_EQ(tree.NumValidNodes(), 3); ASSERT_EQ(tree.NumLeaves(), 2); @@ -148,12 +146,12 @@ TEST(Tree, MultivariateTreeCategoricalSplitConstruction) { std::vector leaf_values(tree_dim, 0.); tree.Init(tree_dim); ASSERT_EQ(tree.LeafVector(0), leaf_values); - tree.ExpandNode(0, 0, std::vector{1,4,6}, leaf_values, leaf_values); + tree.ExpandNode(0, 0, std::vector{1, 4, 6}, leaf_values, leaf_values); ASSERT_EQ(tree.NumNodes(), 3); ASSERT_EQ(tree.NodeType(0), StochTree::TreeNodeType::kCategoricalSplitNode); tree.CollapseToLeaf(0, leaf_values); ASSERT_EQ(tree.NumValidNodes(), 1); - tree.ExpandNode(0, 0, std::vector{2,3,5}, leaf_values, leaf_values); + tree.ExpandNode(0, 0, std::vector{2, 3, 5}, leaf_values, leaf_values); ASSERT_EQ(tree.NodeType(0), StochTree::TreeNodeType::kCategoricalSplitNode); ASSERT_EQ(tree.NumValidNodes(), 3); ASSERT_EQ(tree.NumLeaves(), 2); @@ -167,7 +165,7 @@ TEST(Tree, SparseLeafRepresentation) { StochTree::Tree tree; tree.Init(1); tree.ExpandNode(0, 0, 0.5, 0., 0.); - + // Load test data StochTree::TestUtils::TestDataset test_dataset; test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis(); @@ -180,6 +178,6 @@ TEST(Tree, SparseLeafRepresentation) { // Predict leaf indices of each observation in `dataset` std::vector leaf_index_preds(n); tree.PredictLeafIndexInplace(&dataset, leaf_index_preds, 0, 0); - std::vector leaf_index_expected{1,1,0,1,1,1,1,1,0,0}; + std::vector leaf_index_expected{1, 1, 0, 1, 1, 1, 1, 1, 0, 0}; ASSERT_EQ(leaf_index_expected, leaf_index_preds); } diff --git a/test/cpp/test_unsorted_partition_tracker.cpp b/test/cpp/test_unsorted_partition_tracker.cpp index 067e870f..3ad87f1a 100644 --- a/test/cpp/test_unsorted_partition_tracker.cpp +++ 
b/test/cpp/test_unsorted_partition_tracker.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include TEST(UnsortedNodeSampleTracker, BasicOperations) { @@ -40,13 +38,13 @@ TEST(UnsortedNodeSampleTracker, BasicOperations) { StochTree::TreeSplit tree_split = StochTree::TreeSplit(0.5); node_sample_tracker.PartitionTreeNode(dataset.GetCovariates(), 0, 0, 1, 2, 0, tree_split); sample_node_mapper.AddSplit(dataset.GetCovariates(), tree_split, 0, 0, 0, 1, 2); - + // Check that node begin and node end haven't changed for root node, but that the indices have been sifted ASSERT_EQ(node_sample_tracker.NodeBegin(0, 0), 0); ASSERT_EQ(node_sample_tracker.NodeEnd(0, 0), n); std::vector expected_result{2, 8, 9, 0, 1, 3, 4, 5, 6, 7}; ASSERT_EQ(node_sample_tracker.TreeNodeIndices(0, 0), expected_result); - + // Check that terminal nodes are updated for for every observation ASSERT_EQ(sample_node_mapper.GetNodeId(0, 0), 2); ASSERT_EQ(sample_node_mapper.GetNodeId(1, 0), 2); @@ -58,13 +56,13 @@ TEST(UnsortedNodeSampleTracker, BasicOperations) { ASSERT_EQ(sample_node_mapper.GetNodeId(7, 0), 2); ASSERT_EQ(sample_node_mapper.GetNodeId(8, 0), 1); ASSERT_EQ(sample_node_mapper.GetNodeId(9, 0), 1); - + // Check node begin and node end for left node ASSERT_EQ(node_sample_tracker.NodeBegin(0, 1), 0); ASSERT_EQ(node_sample_tracker.NodeEnd(0, 1), 3); expected_result = {2, 8, 9}; ASSERT_EQ(node_sample_tracker.TreeNodeIndices(0, 1), expected_result); - + // Check node begin and node end for right node ASSERT_EQ(node_sample_tracker.NodeBegin(0, 2), 3); ASSERT_EQ(node_sample_tracker.NodeEnd(0, 2), n); @@ -81,13 +79,13 @@ TEST(UnsortedNodeSampleTracker, BasicOperations) { ASSERT_EQ(node_sample_tracker.NodeEnd(0, 2), n); expected_result = {1, 6, 0, 3, 4, 5, 7}; ASSERT_EQ(node_sample_tracker.TreeNodeIndices(0, 2), expected_result); - + // Check node begin and node end for new left node ASSERT_EQ(node_sample_tracker.NodeBegin(0, 3), 3); ASSERT_EQ(node_sample_tracker.NodeEnd(0, 3), 5); 
expected_result = {1, 6}; ASSERT_EQ(node_sample_tracker.TreeNodeIndices(0, 3), expected_result); - + // Check node begin and node end for new right node ASSERT_EQ(node_sample_tracker.NodeBegin(0, 4), 5); ASSERT_EQ(node_sample_tracker.NodeEnd(0, 4), n); diff --git a/test/cpp/testutils.cpp b/test/cpp/testutils.cpp index e6d967a5..4ae1266a 100644 --- a/test/cpp/testutils.cpp +++ b/test/cpp/testutils.cpp @@ -2,18 +2,14 @@ #include #include #include -#include -#include -#include -#include namespace StochTree { -namespace TestUtils{ +namespace TestUtils { TestDataset LoadSmallDatasetUnivariateBasis() { TestDataset output; - + // Data dimensions output.n = 10; output.x_cols = 5; @@ -24,29 +20,29 @@ TestDataset LoadSmallDatasetUnivariateBasis() { output.rfx_basis.resize(output.n, output.rfx_basis_cols); output.rfx_groups.resize(output.n); output.outcome.resize(output.n); - + // Covariates output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; - + 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, + 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, + 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, + 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, + 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, + 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, + 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, + 
0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, + 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; + // Leaf regression basis output.omega << 0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, 0.86917047, 0.58444831; - + // Outcome - output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, - 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; - + output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, + 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; + // Random effects regression basis (i.e. constant, intercept-only RFX model) output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; - + // Random effects group labels for (int i = 0; i < output.n; i++) { if (i % 2 == 0) { @@ -62,7 +58,7 @@ TestDataset LoadSmallDatasetUnivariateBasis() { TestDataset LoadSmallDatasetMultivariateBasis() { TestDataset output; - + // Data dimensions output.n = 10; output.x_cols = 5; @@ -73,38 +69,38 @@ TestDataset LoadSmallDatasetMultivariateBasis() { output.rfx_basis.resize(output.n, output.rfx_basis_cols); output.rfx_groups.resize(output.n); output.outcome.resize(output.n); - + // Covariates output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; - + 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, + 0.229598754, 0.12461481, 0.81407372, 
0.364336529, 0.45160373, + 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, + 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, + 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, + 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, + 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, + 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, + 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; + // Leaf regression basis - output.omega << 0.97801674, 0.3707159, - 0.34045661, 0.1312134, - 0.20528387, 0.5614470, - 0.76230322, 0.2276504, - 0.63244655, 0.9029984, - 0.61225851, 0.7448547, - 0.40492125, 0.2549813, - 0.33112223, 0.5295535, - 0.86917047, 0.5584614, - 0.58444831, 0.2365117; - + output.omega << 0.97801674, 0.3707159, + 0.34045661, 0.1312134, + 0.20528387, 0.5614470, + 0.76230322, 0.2276504, + 0.63244655, 0.9029984, + 0.61225851, 0.7448547, + 0.40492125, 0.2549813, + 0.33112223, 0.5295535, + 0.86917047, 0.5584614, + 0.58444831, 0.2365117; + // Outcome - output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, - 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; - + output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, + 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; + // Random effects regression basis (i.e. 
constant, intercept-only RFX model) output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; - + // Random effects group labels for (int i = 0; i < output.n; i++) { if (i % 2 == 0) { @@ -120,7 +116,7 @@ TestDataset LoadSmallDatasetMultivariateBasis() { TestDataset LoadSmallRFXDatasetMultivariateBasis() { TestDataset output; - + // Data dimensions output.n = 10; output.x_cols = 5; @@ -131,49 +127,49 @@ TestDataset LoadSmallRFXDatasetMultivariateBasis() { output.rfx_basis.resize(output.n, output.rfx_basis_cols); output.rfx_groups.resize(output.n); output.outcome.resize(output.n); - + // Covariates output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; - + 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, + 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, + 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, + 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, + 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, + 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, + 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, + 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, + 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; + // Leaf regression basis - output.omega << 0.97801674, 0.3707159, - 0.34045661, 0.1312134, - 0.20528387, 0.5614470, - 0.76230322, 0.2276504, - 0.63244655, 0.9029984, - 
0.61225851, 0.7448547, - 0.40492125, 0.2549813, - 0.33112223, 0.5295535, - 0.86917047, 0.5584614, - 0.58444831, 0.2365117; - + output.omega << 0.97801674, 0.3707159, + 0.34045661, 0.1312134, + 0.20528387, 0.5614470, + 0.76230322, 0.2276504, + 0.63244655, 0.9029984, + 0.61225851, 0.7448547, + 0.40492125, 0.2549813, + 0.33112223, 0.5295535, + 0.86917047, 0.5584614, + 0.58444831, 0.2365117; + // Outcome - output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, - 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; - + output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, + 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; + // Random effects regression basis (i.e. constant, intercept-only RFX model) - output.rfx_basis << 1, 0.3707159, - 1, 0.1312134, - 1, 0.5614470, - 1, 0.2276504, - 1, 0.9029984, - 1, 0.7448547, - 1, 0.2549813, - 1, 0.5295535, - 1, 0.5584614, - 1, 0.2365117; - + output.rfx_basis << 1, 0.3707159, + 1, 0.1312134, + 1, 0.5614470, + 1, 0.2276504, + 1, 0.9029984, + 1, 0.7448547, + 1, 0.2549813, + 1, 0.5295535, + 1, 0.5584614, + 1, 0.2365117; + // Random effects group labels - output.rfx_groups = {1,2,3,1,2,3,1,2,3,1}; + output.rfx_groups = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1}; // for (int i = 0; i < output.n; i++) { // if (i % 2 == 0) { // output.rfx_groups[i] = 1; @@ -188,7 +184,7 @@ TestDataset LoadSmallRFXDatasetMultivariateBasis() { TestDataset LoadMediumDatasetUnivariateBasis() { TestDataset output; - + // Data dimensions output.n = 100; output.x_cols = 5; @@ -199,148 +195,148 @@ TestDataset LoadMediumDatasetUnivariateBasis() { output.rfx_basis.resize(output.n, output.rfx_basis_cols); output.rfx_groups.resize(output.n); output.outcome.resize(output.n); - + // Covariates output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 
0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569, - 0.012786428, 0.46925501, 0.25363201, 0.3429851863, 0.2071495, - 0.887479904, 0.66166194, 0.31100105, 0.2895678403, 0.00117005, - 0.147758652, 0.14108789, 0.0361254, 0.4790630946, 0.47336526, - 0.899947367, 0.03730855, 0.33408769, 0.368503517, 0.30600202, - 0.527616998, 0.22344076, 0.20325828, 0.9296060419, 0.34518043, - 0.947085596, 0.85906392, 0.35535464, 0.529360628, 0.8781696, - 0.716097994, 0.9149628, 0.11689428, 0.1157865208, 0.31602707, - 0.433331308, 0.53848417, 0.34146036, 0.4967994317, 0.12822296, - 0.420861259, 0.28802486, 0.62324752, 0.2045601751, 0.06909585, - 0.275279159, 0.69079999, 0.29498051, 0.0082852058, 0.45247107, - 0.909681016, 0.35067747, 0.66813255, 0.3866910117, 0.65315347, - 0.828031845, 0.74096924, 0.33982958, 0.0009472317, 0.65103292, - 0.261653444, 0.43179244, 0.89632155, 0.8636559783, 0.93461464, - 0.209384357, 0.12561389, 0.69809409, 0.4752417156, 0.34963379, - 0.737655852, 0.42078584, 0.09970929, 0.5218528947, 0.36737846, - 0.975034732, 0.69977514, 0.33918481, 0.5443784453, 0.35411297, - 0.053533786, 0.98021485, 0.71035393, 0.189234901, 0.73372176, - 0.364139644, 0.47595789, 0.24620073, 0.4284725219, 0.46145259, - 0.696115067, 0.18095114, 0.66919045, 0.9517078404, 0.31686943, - 0.920878008, 0.89758374, 0.21445324, 0.5666448742, 0.29554824, - 0.397853079, 0.12019741, 0.10775046, 0.0799620333, 0.20065807, - 0.322087545, 0.68342919, 0.29873607, 0.0044371644, 0.66733723, - 0.661407114, 0.0558764, 0.10688295, 0.067841246, 0.52254161, - 0.593253554, 0.40498486, 0.97342655, 0.1917967587, 
0.2078643, - 0.392762915, 0.91608107, 0.98894976, 0.3599016496, 0.70576753, - 0.758995247, 0.19899099, 0.95978035, 0.8000916124, 0.8356055, - 0.105617762, 0.12135206, 0.47523114, 0.3594282658, 0.71053726, - 0.754330984, 0.803395, 0.11297253, 0.5072350584, 0.05109695, - 0.410083859, 0.13842349, 0.3671543, 0.262290115, 0.76582706, - 0.498883172, 0.52094766, 0.23674406, 0.8919167451, 0.26313017, - 0.315790046, 0.57934811, 0.96794023, 0.7292640421, 0.63874656, - 0.969918807, 0.86839672, 0.17867962, 0.797609952, 0.3123159, - 0.291589217, 0.37982099, 0.92081884, 0.3760313739, 0.30599535, - 0.874146047, 0.64472863, 0.74944373, 0.0179410274, 0.06637048, - 0.006168369, 0.36819005, 0.48640614, 0.5182905369, 0.37514676, - 0.018794786, 0.50404546, 0.30706335, 0.239409535, 0.78368968, - 0.218041312, 0.08232156, 0.910968, 0.236348928, 0.08734924, - 0.240712896, 0.81851635, 0.75910757, 0.7666831033, 0.51030368, - 0.32422135, 0.37234399, 0.4268269, 0.0688136201, 0.52522145, - 0.737050103, 0.55333162, 0.35681609, 0.5527229193, 0.45528166, - 0.666105454, 0.44928217, 0.93068357, 0.2682658806, 0.47992145, - 0.072705164, 0.24379538, 0.36250275, 0.2693803106, 0.88583253, - 0.393483048, 0.7180344, 0.88936403, 0.9690254654, 0.41720031, - 0.726532397, 0.15675097, 0.14675637, 0.973136256, 0.86701643, - 0.206543021, 0.70612692, 0.9923119, 0.1270776591, 0.43317344, - 0.392393596, 0.6581254, 0.51121301, 0.8005079071, 0.16056554, - 0.326374607, 0.48817642, 0.68630408, 0.9265561129, 0.48683193, - 0.761818521, 0.71751337, 0.83854992, 0.134206275, 0.25700676, - 0.930924999, 0.37469277, 0.42861545, 0.7379696709, 0.9670993, - 0.601101112, 0.56631699, 0.85690728, 0.0792362478, 0.23640603, - 0.294070227, 0.02818223, 0.83060893, 0.8203584203, 0.17647972, - 0.393978659, 0.88639966, 0.80788018, 0.4202279691, 0.75344798, - 0.381183787, 0.98751161, 0.13933232, 0.5427466533, 0.15809025, - 0.203872876, 0.31032719, 0.53000948, 0.6001499062, 0.43581315, - 0.355075927, 0.10865708, 0.21823445, 0.5707600345, 
0.84459087, - 0.415892882, 0.09056941, 0.85957968, 0.9296874236, 0.39317951, - 0.885163931, 0.60617414, 0.22888755, 0.9225545505, 0.41601782, - 0.803631177, 0.63855664, 0.4968153, 0.4970232591, 0.28230652, - 0.755692566, 0.36382158, 0.31492054, 0.9853899847, 0.45864754, - 0.761099141, 0.88094342, 0.82542666, 0.977985516, 0.5416208, - 0.536037115, 0.19298885, 0.67674639, 0.213044832, 0.29409245, - 0.050087478, 0.56597845, 0.22309031, 0.7668617836, 0.02385271, - 0.847882026, 0.86580035, 0.8381724, 0.618777399, 0.4707389, - 0.280194086, 0.95490103, 0.27399251, 0.5894525715, 0.17181438, - 0.261382768, 0.96124295, 0.33737123, 0.3545607659, 0.36367031, - 0.465759262, 0.17167592, 0.87114988, 0.4175856721, 0.16020522, - 0.982323635, 0.30892377, 0.96513595, 0.376671114, 0.9411435, - 0.851789546, 0.42260807, 0.37396782, 0.0759502219, 0.41219659, - 0.23932738, 0.70124641, 0.08544481, 0.8599137105, 0.35298377, - 0.985171556, 0.48493665, 0.92919919, 0.3128095574, 0.84388465, - 0.936608667, 0.70159722, 0.23570122, 0.5124408882, 0.99478731, - 0.328337863, 0.83252833, 0.29078719, 0.7531193637, 0.49378383, - 0.504403078, 0.72845174, 0.12801659, 0.5383322216, 0.12559066, - 0.906952623, 0.36801267, 0.13168735, 0.9791060984, 0.14008791, - 0.454210506, 0.67248289, 0.4041049, 0.234963659, 0.92138674, - 0.499037576, 0.7534805, 0.4168877, 0.6275620307, 0.24189188, - 0.707788941, 0.91990553, 0.56701198, 0.1408275496, 0.80566006, - 0.694437274, 0.69339343, 0.42296251, 0.8271595608, 0.53699966, - 0.447118821, 0.97512181, 0.16431204, 0.3697280197, 0.38753206, - 0.885936489, 0.94468978, 0.48918779, 0.3676202064, 0.06938232, - 0.593980148, 0.28140352, 0.27760537, 0.2819242389, 0.8730862, - 0.04248501, 0.45279893, 0.69760642, 0.0949480394, 0.42568701, - 0.35842742, 0.68098838, 0.82745029, 0.5315801166, 0.31104918, - 0.724621041, 0.28763999, 0.48743089, 0.8648093319, 0.93792148, - 0.961828358, 0.5548953, 0.7250596, 0.249875583, 0.90661302, - 0.251438316, 0.86021024, 0.65037498, 0.209739062, 
0.07886205, - 0.699615913, 0.12223695, 0.20393331, 0.6357937951, 0.81502268, - 0.391076967, 0.25143855, 0.16091307, 0.6037441837, 0.50651534, - 0.343597198, 0.82570727, 0.62455707, 0.6284155636, 0.17288776, - 0.451352309, 0.29346835, 0.12641623, 0.1194773833, 0.88849468; - + 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, + 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, + 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, + 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, + 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, + 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, + 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, + 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, + 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569, + 0.012786428, 0.46925501, 0.25363201, 0.3429851863, 0.2071495, + 0.887479904, 0.66166194, 0.31100105, 0.2895678403, 0.00117005, + 0.147758652, 0.14108789, 0.0361254, 0.4790630946, 0.47336526, + 0.899947367, 0.03730855, 0.33408769, 0.368503517, 0.30600202, + 0.527616998, 0.22344076, 0.20325828, 0.9296060419, 0.34518043, + 0.947085596, 0.85906392, 0.35535464, 0.529360628, 0.8781696, + 0.716097994, 0.9149628, 0.11689428, 0.1157865208, 0.31602707, + 0.433331308, 0.53848417, 0.34146036, 0.4967994317, 0.12822296, + 0.420861259, 0.28802486, 0.62324752, 0.2045601751, 0.06909585, + 0.275279159, 0.69079999, 0.29498051, 0.0082852058, 0.45247107, + 0.909681016, 0.35067747, 0.66813255, 0.3866910117, 0.65315347, + 0.828031845, 0.74096924, 0.33982958, 0.0009472317, 0.65103292, + 0.261653444, 0.43179244, 0.89632155, 0.8636559783, 0.93461464, + 0.209384357, 0.12561389, 0.69809409, 0.4752417156, 0.34963379, + 0.737655852, 0.42078584, 0.09970929, 0.5218528947, 0.36737846, + 0.975034732, 0.69977514, 0.33918481, 0.5443784453, 0.35411297, + 0.053533786, 0.98021485, 0.71035393, 0.189234901, 0.73372176, + 0.364139644, 0.47595789, 0.24620073, 0.4284725219, 
0.46145259, + 0.696115067, 0.18095114, 0.66919045, 0.9517078404, 0.31686943, + 0.920878008, 0.89758374, 0.21445324, 0.5666448742, 0.29554824, + 0.397853079, 0.12019741, 0.10775046, 0.0799620333, 0.20065807, + 0.322087545, 0.68342919, 0.29873607, 0.0044371644, 0.66733723, + 0.661407114, 0.0558764, 0.10688295, 0.067841246, 0.52254161, + 0.593253554, 0.40498486, 0.97342655, 0.1917967587, 0.2078643, + 0.392762915, 0.91608107, 0.98894976, 0.3599016496, 0.70576753, + 0.758995247, 0.19899099, 0.95978035, 0.8000916124, 0.8356055, + 0.105617762, 0.12135206, 0.47523114, 0.3594282658, 0.71053726, + 0.754330984, 0.803395, 0.11297253, 0.5072350584, 0.05109695, + 0.410083859, 0.13842349, 0.3671543, 0.262290115, 0.76582706, + 0.498883172, 0.52094766, 0.23674406, 0.8919167451, 0.26313017, + 0.315790046, 0.57934811, 0.96794023, 0.7292640421, 0.63874656, + 0.969918807, 0.86839672, 0.17867962, 0.797609952, 0.3123159, + 0.291589217, 0.37982099, 0.92081884, 0.3760313739, 0.30599535, + 0.874146047, 0.64472863, 0.74944373, 0.0179410274, 0.06637048, + 0.006168369, 0.36819005, 0.48640614, 0.5182905369, 0.37514676, + 0.018794786, 0.50404546, 0.30706335, 0.239409535, 0.78368968, + 0.218041312, 0.08232156, 0.910968, 0.236348928, 0.08734924, + 0.240712896, 0.81851635, 0.75910757, 0.7666831033, 0.51030368, + 0.32422135, 0.37234399, 0.4268269, 0.0688136201, 0.52522145, + 0.737050103, 0.55333162, 0.35681609, 0.5527229193, 0.45528166, + 0.666105454, 0.44928217, 0.93068357, 0.2682658806, 0.47992145, + 0.072705164, 0.24379538, 0.36250275, 0.2693803106, 0.88583253, + 0.393483048, 0.7180344, 0.88936403, 0.9690254654, 0.41720031, + 0.726532397, 0.15675097, 0.14675637, 0.973136256, 0.86701643, + 0.206543021, 0.70612692, 0.9923119, 0.1270776591, 0.43317344, + 0.392393596, 0.6581254, 0.51121301, 0.8005079071, 0.16056554, + 0.326374607, 0.48817642, 0.68630408, 0.9265561129, 0.48683193, + 0.761818521, 0.71751337, 0.83854992, 0.134206275, 0.25700676, + 0.930924999, 0.37469277, 0.42861545, 0.7379696709, 
0.9670993, + 0.601101112, 0.56631699, 0.85690728, 0.0792362478, 0.23640603, + 0.294070227, 0.02818223, 0.83060893, 0.8203584203, 0.17647972, + 0.393978659, 0.88639966, 0.80788018, 0.4202279691, 0.75344798, + 0.381183787, 0.98751161, 0.13933232, 0.5427466533, 0.15809025, + 0.203872876, 0.31032719, 0.53000948, 0.6001499062, 0.43581315, + 0.355075927, 0.10865708, 0.21823445, 0.5707600345, 0.84459087, + 0.415892882, 0.09056941, 0.85957968, 0.9296874236, 0.39317951, + 0.885163931, 0.60617414, 0.22888755, 0.9225545505, 0.41601782, + 0.803631177, 0.63855664, 0.4968153, 0.4970232591, 0.28230652, + 0.755692566, 0.36382158, 0.31492054, 0.9853899847, 0.45864754, + 0.761099141, 0.88094342, 0.82542666, 0.977985516, 0.5416208, + 0.536037115, 0.19298885, 0.67674639, 0.213044832, 0.29409245, + 0.050087478, 0.56597845, 0.22309031, 0.7668617836, 0.02385271, + 0.847882026, 0.86580035, 0.8381724, 0.618777399, 0.4707389, + 0.280194086, 0.95490103, 0.27399251, 0.5894525715, 0.17181438, + 0.261382768, 0.96124295, 0.33737123, 0.3545607659, 0.36367031, + 0.465759262, 0.17167592, 0.87114988, 0.4175856721, 0.16020522, + 0.982323635, 0.30892377, 0.96513595, 0.376671114, 0.9411435, + 0.851789546, 0.42260807, 0.37396782, 0.0759502219, 0.41219659, + 0.23932738, 0.70124641, 0.08544481, 0.8599137105, 0.35298377, + 0.985171556, 0.48493665, 0.92919919, 0.3128095574, 0.84388465, + 0.936608667, 0.70159722, 0.23570122, 0.5124408882, 0.99478731, + 0.328337863, 0.83252833, 0.29078719, 0.7531193637, 0.49378383, + 0.504403078, 0.72845174, 0.12801659, 0.5383322216, 0.12559066, + 0.906952623, 0.36801267, 0.13168735, 0.9791060984, 0.14008791, + 0.454210506, 0.67248289, 0.4041049, 0.234963659, 0.92138674, + 0.499037576, 0.7534805, 0.4168877, 0.6275620307, 0.24189188, + 0.707788941, 0.91990553, 0.56701198, 0.1408275496, 0.80566006, + 0.694437274, 0.69339343, 0.42296251, 0.8271595608, 0.53699966, + 0.447118821, 0.97512181, 0.16431204, 0.3697280197, 0.38753206, + 0.885936489, 0.94468978, 0.48918779, 0.3676202064, 
0.06938232, + 0.593980148, 0.28140352, 0.27760537, 0.2819242389, 0.8730862, + 0.04248501, 0.45279893, 0.69760642, 0.0949480394, 0.42568701, + 0.35842742, 0.68098838, 0.82745029, 0.5315801166, 0.31104918, + 0.724621041, 0.28763999, 0.48743089, 0.8648093319, 0.93792148, + 0.961828358, 0.5548953, 0.7250596, 0.249875583, 0.90661302, + 0.251438316, 0.86021024, 0.65037498, 0.209739062, 0.07886205, + 0.699615913, 0.12223695, 0.20393331, 0.6357937951, 0.81502268, + 0.391076967, 0.25143855, 0.16091307, 0.6037441837, 0.50651534, + 0.343597198, 0.82570727, 0.62455707, 0.6284155636, 0.17288776, + 0.451352309, 0.29346835, 0.12641623, 0.1194773833, 0.88849468; + // Leaf regression basis output.omega << 0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, - 0.86917047, 0.58444831, 0.33316433, 0.62217709, 0.96820668, 0.20778425, 0.23764591, 0.94193115, - 0.03869153, 0.60847765, 0.51535811, 0.81554404, 0.78515289, 0.23337815, 0.16730957, 0.02168331, - 0.08699654, 0.34067049, 0.93141264, 0.03679176, 0.4364772, 0.2644173, 0.23717182, 0.59084776, - 0.63438143, 0.57132227, 0.17568721, 0.15552373, 0.8625478, 0.02466334, 0.47269628, 0.97782225, - 0.90593388, 0.82272111, 0.67374992, 0.47619752, 0.5276532, 0.75182919, 0.09559243, 0.5126907, - 0.45892102, 0.11357212, 0.77861167, 0.78424907, 0.84693988, 0.38814934, 0.01010333, 0.10064384, - 0.68664865, 0.1264298, 0.14314708, 0.62679815, 0.71101772, 0.43504811, 0.8868721, 0.95098048, - 0.38291537, 0.71337451, 0.12109764, 0.68943347, 0.89878588, 0.67524475, 0.95549402, 0.58758459, - 0.68558459, 0.16794963, 0.23680754, 0.40289479, 0.98291039, 0.87276966, 0.76995475, 0.55282963, - 0.12448394, 0.5479543, 0.8718802, 0.14515363, 0.71311006, 0.39196408, 0.94504373, 0.44020353, - 0.24090674, 0.52675625, 0.86674581, 0.90576332, 0.09167602, 0.74795585, 0.26901811, 0.544173, - 0.03336554, 0.8314331, 0.27185696, 0.83434459; - + 0.86917047, 0.58444831, 0.33316433, 0.62217709, 0.96820668, 0.20778425, 0.23764591, 
0.94193115, + 0.03869153, 0.60847765, 0.51535811, 0.81554404, 0.78515289, 0.23337815, 0.16730957, 0.02168331, + 0.08699654, 0.34067049, 0.93141264, 0.03679176, 0.4364772, 0.2644173, 0.23717182, 0.59084776, + 0.63438143, 0.57132227, 0.17568721, 0.15552373, 0.8625478, 0.02466334, 0.47269628, 0.97782225, + 0.90593388, 0.82272111, 0.67374992, 0.47619752, 0.5276532, 0.75182919, 0.09559243, 0.5126907, + 0.45892102, 0.11357212, 0.77861167, 0.78424907, 0.84693988, 0.38814934, 0.01010333, 0.10064384, + 0.68664865, 0.1264298, 0.14314708, 0.62679815, 0.71101772, 0.43504811, 0.8868721, 0.95098048, + 0.38291537, 0.71337451, 0.12109764, 0.68943347, 0.89878588, 0.67524475, 0.95549402, 0.58758459, + 0.68558459, 0.16794963, 0.23680754, 0.40289479, 0.98291039, 0.87276966, 0.76995475, 0.55282963, + 0.12448394, 0.5479543, 0.8718802, 0.14515363, 0.71311006, 0.39196408, 0.94504373, 0.44020353, + 0.24090674, 0.52675625, 0.86674581, 0.90576332, 0.09167602, 0.74795585, 0.26901811, 0.544173, + 0.03336554, 0.8314331, 0.27185696, 0.83434459; + // Outcome output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, 0.347249942, 0.546179903, - 1.164750138, 3.389946886, -0.605464414, 1.271432631, 2.203609096, 2.192327323, 0.746140817, - 3.009233058, -0.292800298, 1.752730639, 1.824961588, 2.055603702, -0.153889672, 0.248010541, - 1.099472562, 0.822333874, 1.291797503, 0.877720106, 2.365239601, 0.685716301, 1.445624363, - 1.342180906, 0.148136818, -1.157010472, 2.186988614, 1.523371203, 1.740153725, 0.73351857, - 0.449967161, 1.25200968, 1.155083428, 1.580760814, 3.025557265, 1.488059405, -0.069025021, - 1.100181892, 1.014150762, 0.418207324, 3.210834777, 1.658875834, 2.215173806, 1.351802193, - 1.33331705, 2.357354695, -1.449598055, 1.042660314, 0.404779346, 1.35048031, -0.58922199, - -0.281044393, 0.128478258, 0.006620112, 1.237840372, 1.0999817, 2.245489523, 2.114281687, - 1.337789336, 0.668884629, 2.275744698, 1.483665856, 0.577564239, -0.557180209, 3.810578895, - 
0.946494502, 1.464014296, 0.793749131, 2.735140925, 2.037714409, 1.530792369, 1.857142205, - 1.015348805, -0.91839562, 1.924546112, -0.218826033, 1.761318971, 0.928338732, 1.109589807, - 2.165307398, 2.258640565, 1.147428989, 0.332872857, 0.373646084, 0.520770108, 1.857996323, - -1.971537882, 0.962010578, 1.552073631, 0.459464684, -0.149159276, 0.203079262, -0.453721958, 2.152977755, 0.948865461; - + 1.164750138, 3.389946886, -0.605464414, 1.271432631, 2.203609096, 2.192327323, 0.746140817, + 3.009233058, -0.292800298, 1.752730639, 1.824961588, 2.055603702, -0.153889672, 0.248010541, + 1.099472562, 0.822333874, 1.291797503, 0.877720106, 2.365239601, 0.685716301, 1.445624363, + 1.342180906, 0.148136818, -1.157010472, 2.186988614, 1.523371203, 1.740153725, 0.73351857, + 0.449967161, 1.25200968, 1.155083428, 1.580760814, 3.025557265, 1.488059405, -0.069025021, + 1.100181892, 1.014150762, 0.418207324, 3.210834777, 1.658875834, 2.215173806, 1.351802193, + 1.33331705, 2.357354695, -1.449598055, 1.042660314, 0.404779346, 1.35048031, -0.58922199, + -0.281044393, 0.128478258, 0.006620112, 1.237840372, 1.0999817, 2.245489523, 2.114281687, + 1.337789336, 0.668884629, 2.275744698, 1.483665856, 0.577564239, -0.557180209, 3.810578895, + 0.946494502, 1.464014296, 0.793749131, 2.735140925, 2.037714409, 1.530792369, 1.857142205, + 1.015348805, -0.91839562, 1.924546112, -0.218826033, 1.761318971, 0.928338732, 1.109589807, + 2.165307398, 2.258640565, 1.147428989, 0.332872857, 0.373646084, 0.520770108, 1.857996323, + -1.971537882, 0.962010578, 1.552073631, 0.459464684, -0.149159276, 0.203079262, -0.453721958, 2.152977755, 0.948865461; + // Random effects regression basis (i.e. 
constant, intercept-only RFX model) - output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; - + output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; + // Random effects group labels for (int i = 0; i < output.n; i++) { if (i % 2 == 0) { @@ -354,6 +350,6 @@ TestDataset LoadMediumDatasetUnivariateBasis() { return output; } -} +} // namespace TestUtils } // namespace StochTree diff --git a/test/cpp/testutils.h b/test/cpp/testutils.h index 09b43133..1b10b441 100644 --- a/test/cpp/testutils.h +++ b/test/cpp/testutils.h @@ -39,8 +39,8 @@ TestDataset LoadSmallRFXDatasetMultivariateBasis(); /*! 
Creates a modest dataset (100 observations) */ TestDataset LoadMediumDatasetUnivariateBasis(); -} // namespace TestUtils +} // namespace TestUtils -} // namespace StochTree +} // namespace StochTree #endif // STOCHTREE_TESTUTILS_H_ From c6fd61ba80d61944be0d12c4e40618a86c5c523d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 18:45:01 -0400 Subject: [PATCH 19/64] Refactored Eigen out of debug programs --- debug/bart_debug.cpp | 60 ++++++++++++++++------------- debug/bcf_debug.cpp | 91 ++++++++++++++++++++++++-------------------- 2 files changed, 84 insertions(+), 67 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 5f3b0bec..9a62a012 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -25,7 +25,6 @@ #include #include -#include #include #include #include @@ -38,14 +37,14 @@ static constexpr double kPi = 3.14159265358979323846; // ---- Data ------------------------------------------------------------ struct RegressionDataset { - Eigen::Matrix X; - Eigen::VectorXd y; + std::vector X; + std::vector y; }; struct ProbitDataset { - Eigen::Matrix X; - Eigen::VectorXd y; - Eigen::VectorXd Z; + std::vector X; + std::vector y; + std::vector Z; }; // DGP: y ~ sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) @@ -53,13 +52,13 @@ RegressionDataset generate_constant_leaf_regression_data(int n, int p, std::mt19 std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); RegressionDataset d; - d.X.resize(n, p); + d.X.resize(n * p); d.y.resize(n); for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) - d.X(i, j) = unif(rng); + d.X[j * n + i] = unif(rng); for (int i = 0; i < n; i++) - d.y(i) = std::sin(2.0 * kPi * d.X(i, 0)) + 0.5 * d.X(i, 1) - 1.5 * d.X(i, 2) + normal(rng); + d.y[i] = std::sin(2.0 * kPi * d.X[i]) + 0.5 * d.X[1 * n + i] - 1.5 * d.X[2 * n + i] + normal(rng); return d; } @@ -70,17 +69,16 @@ RegressionDataset generate_constant_leaf_regression_data(int n, int p, std::mt19 ProbitDataset 
generate_probit_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); - Eigen::VectorXd Z; ProbitDataset d; - d.X.resize(n, p); + d.X.resize(n * p); d.y.resize(n); d.Z.resize(n); for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) - d.X(i, j) = unif(rng); + d.X[j * n + i] = unif(rng); for (int i = 0; i < n; i++) { - d.Z(i) = std::sin(2.0 * kPi * d.X(i, 0)) + 0.5 * d.X(i, 1) - 1.5 * d.X(i, 2) + normal(rng); - d.y(i) = (d.Z(i) > 0) ? 1.0 : 0.0; + d.Z[i] = std::sin(2.0 * kPi * d.X[i]) + 0.5 * d.X[1 * n + i] - 1.5 * d.X[2 * n + i] + normal(rng); + d.y[i] = (d.Z[i] > 0) ? 1.0 : 0.0; } return d; } @@ -184,13 +182,20 @@ void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int nu // Generate data RegressionDataset data = generate_constant_leaf_regression_data(n, p, rng); - double y_bar = data.y.mean(); - double y_std = std::sqrt((data.y.array() - y_bar).square().sum() / (data.y.size() - 1)); - Eigen::VectorXd resid_vec = (data.y.array() - y_bar) / y_std; // standardize + double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); + double y_std = 0; + for (int i = 0; i < n; i++) { + y_std += (data.y[i] - y_bar) * (data.y[i] - y_bar); + } + y_std = std::sqrt(y_std / n); + std::vector resid_vec(data.y.size()); + for (std::size_t i = 0; i < data.y.size(); i++) { + resid_vec[i] = (data.y[i] - y_bar) / y_std; + } // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); StochTree::ColumnVector residual(resid_vec.data(), n); // Initialize global error variance model @@ -205,7 +210,7 @@ void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int nu // Generate test data and build test dataset RegressionDataset test_data = generate_constant_leaf_regression_data(n_test, p, rng); 
StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); // Lambda function for reporting test-set RMSE and last draw of global error variance model auto report = [&](const std::vector& preds, double global_variance) { @@ -214,7 +219,7 @@ void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int nu double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; - double err = (mu_hat * y_std + y_bar) - test_data.y(i); + double err = (mu_hat * y_std + y_bar) - test_data.y[i]; rmse_sum += err * err; } std::cout << "\nScenario 0 (Homoskedastic BART):\n" @@ -242,13 +247,16 @@ void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int nu // Generate data ProbitDataset data = generate_probit_data(n, p, rng); - double y_bar = StochTree::norm_cdf(data.y.mean()); - Eigen::VectorXd y_vec = data.y.array(); - Eigen::VectorXd Z_vec = (data.y.array() - y_bar); + double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); + std::vector y_vec = data.y; + std::vector Z_vec(n); + for (int i = 0; i < n; i++) { + Z_vec[i] = data.y[i] - y_bar; + } // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); StochTree::ColumnVector residual(Z_vec.data(), n); // Lambda function for probit data augmentation sampling step (after each forest sample) @@ -260,7 +268,7 @@ void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int nu // Generate test data and build test dataset ProbitDataset test_data = generate_probit_data(n_test, p, rng); StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + 
test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); // Lambda function for reporting test-set RMSE auto report = [&](const std::vector& preds, double global_variance) { @@ -269,7 +277,7 @@ void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int nu double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; - double err = (mu_hat + y_bar) - test_data.Z(i); + double err = (mu_hat + y_bar) - test_data.Z[i]; rmse_sum += err * err; } std::cout << "\nScenario 1 (Probit BART):\n" diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index affcf66c..ff91552c 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -30,7 +30,6 @@ #include #include -#include #include #include #include @@ -43,20 +42,20 @@ static constexpr double kPi = 3.14159265358979323846; // ---- Data ------------------------------------------------------------ struct SimpleBCFDataset { - Eigen::Matrix X; - Eigen::VectorXd y; - Eigen::VectorXd z; - Eigen::VectorXd mu_true; - Eigen::VectorXd tau_true; + std::vector X; + std::vector y; + std::vector z; + std::vector mu_true; + std::vector tau_true; }; struct ProbitBCFDataset { - Eigen::Matrix X; - Eigen::VectorXd y; - Eigen::VectorXd latent_outcome; - Eigen::VectorXd z; - Eigen::VectorXd mu_true; - Eigen::VectorXd tau_true; + std::vector X; + std::vector y; + std::vector latent_outcome; + std::vector z; + std::vector mu_true; + std::vector tau_true; }; SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { @@ -65,7 +64,7 @@ SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { std::bernoulli_distribution bern(0.5); SimpleBCFDataset d; - d.X.resize(n, p); + d.X.resize(n * p); d.y.resize(n); d.z.resize(n); d.mu_true.resize(n); @@ -73,13 +72,13 @@ SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) - d.X(i, j) = unif(rng); + d.X[j 
* n + i] = unif(rng); for (int i = 0; i < n; i++) { - d.z(i) = bern(rng) ? 1.0 : 0.0; - d.mu_true(i) = 2.0 * std::sin(kPi * d.X(i, 0)) + 0.5 * d.X(i, 1); - d.tau_true(i) = 1.0 + d.X(i, 2); - d.y(i) = d.mu_true(i) + d.tau_true(i) * d.z(i) + 0.5 * normal(rng); + d.z[i] = bern(rng) ? 1.0 : 0.0; + d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; + d.tau_true[i] = 1.0 + d.X[2 * n + i]; + d.y[i] = d.mu_true[i] + d.tau_true[i] * d.z[i] + 0.5 * normal(rng); } return d; } @@ -90,7 +89,7 @@ ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { std::bernoulli_distribution bern(0.5); ProbitBCFDataset d; - d.X.resize(n, p); + d.X.resize(n * p); d.y.resize(n); d.z.resize(n); d.mu_true.resize(n); @@ -99,14 +98,14 @@ ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) - d.X(i, j) = unif(rng); + d.X[j * n + i] = unif(rng); for (int i = 0; i < n; i++) { - d.z(i) = bern(rng) ? 1.0 : 0.0; - d.mu_true(i) = 2.0 * std::sin(kPi * d.X(i, 0)) + 0.5 * d.X(i, 1); - d.tau_true(i) = 1.0 + d.X(i, 2); - d.latent_outcome(i) = d.mu_true(i) + d.tau_true(i) * d.z(i) + normal(rng); - d.y(i) = (d.latent_outcome(i) > 0.0) ? 1.0 : 0.0; + d.z[i] = bern(rng) ? 1.0 : 0.0; + d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; + d.tau_true[i] = 1.0 + d.X[2 * n + i]; + d.latent_outcome[i] = d.mu_true[i] + d.tau_true[i] * d.z[i] + normal(rng); + d.y[i] = (d.latent_outcome[i] > 0.0) ? 
1.0 : 0.0; } return d; } @@ -258,13 +257,20 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta // Generate data and standardize outcome SimpleBCFDataset data = generate_simple_bcf_data(n, p, rng); - double y_bar = data.y.mean(); - double y_std = std::sqrt((data.y.array() - y_bar).square().mean()); - Eigen::VectorXd resid_vec = (data.y.array() - y_bar) / y_std; // standardize + double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); + double y_std = 0; + for (int i = 0; i < n; i++) { + y_std += (data.y[i] - y_bar) * (data.y[i] - y_bar); + } + y_std = std::sqrt(y_std / n); + std::vector resid_vec(n); + for (int i = 0; i < n; i++) { + resid_vec[i] = (data.y[i] - y_bar) / y_std; + } // Shared dataset: only tau forest uses the Z basis for leaf regression StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); + dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); dataset.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); // Shared residual @@ -284,7 +290,7 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta // Test dataset: covariates + actual treatment z (for y prediction) StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); + test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, /*row_major=*/false); // Lambda function for reporting mu/tau RMSE and last draw of global error variance @@ -296,16 +302,16 @@ void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_ta double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; - mu_rmse_sum += (mu_hat * y_std + y_bar - test_data.mu_true(i)) * (mu_hat * y_std + y_bar - test_data.mu_true(i)); + mu_rmse_sum += (mu_hat * y_std + y_bar - 
test_data.mu_true[i]) * (mu_hat * y_std + y_bar - test_data.mu_true[i]); // tau_preds from test_dataset_cate (z=1 basis) => raw CATE estimates double cate_hat = 0.0; for (int j = 0; j < num_mcmc; j++) cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; - tau_rmse_sum += (cate_hat * y_std - test_data.tau_true(i)) * (cate_hat * y_std - test_data.tau_true(i)); + tau_rmse_sum += (cate_hat * y_std - test_data.tau_true[i]) * (cate_hat * y_std - test_data.tau_true[i]); - double y_hat = mu_hat * y_std + y_bar + cate_hat * test_data.z(i) * y_std; - y_rmse_sum += (y_hat - test_data.y(i)) * (y_hat - test_data.y(i)); + double y_hat = mu_hat * y_std + y_bar + cate_hat * test_data.z[i] * y_std; + y_rmse_sum += (y_hat - test_data.y[i]) * (y_hat - test_data.y[i]); } std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" @@ -336,9 +342,12 @@ void run_scenario_1(int n, int n_test, int p, int num_trees_mu, int num_trees_ta // Generate data and standardize outcome ProbitBCFDataset data = generate_probit_bcf_data(n, p, rng); - double y_bar = StochTree::norm_cdf(data.y.mean()); - Eigen::VectorXd y_vec = data.y.array(); - Eigen::VectorXd Z_vec = (data.y.array() - y_bar); + double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); + std::vector Z_vec(n); + for (int i = 0; i < n; i++) { + Z_vec[i] = data.y[i] - y_bar; + } + std::vector y_vec = data.y; // Shared dataset: only tau forest uses the Z basis for leaf regression StochTree::ForestDataset dataset; @@ -375,16 +384,16 @@ void run_scenario_1(int n, int n_test, int p, int num_trees_mu, int num_trees_ta double mu_hat = 0.0; for (int j = 0; j < num_mcmc; j++) mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; - mu_rmse_sum += (mu_hat + y_bar - test_data.mu_true(i)) * (mu_hat + y_bar - test_data.mu_true(i)); + mu_rmse_sum += (mu_hat + y_bar - test_data.mu_true[i]) * (mu_hat + y_bar - test_data.mu_true[i]); // tau_preds from test_dataset_cate (z=1 basis) => raw CATE 
estimates double cate_hat = 0.0; for (int j = 0; j < num_mcmc; j++) cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; - tau_rmse_sum += (cate_hat - test_data.tau_true(i)) * (cate_hat - test_data.tau_true(i)); + tau_rmse_sum += (cate_hat - test_data.tau_true[i]) * (cate_hat - test_data.tau_true[i]); - double y_hat = mu_hat + y_bar + cate_hat * test_data.z(i); - y_rmse_sum += (y_hat - test_data.latent_outcome(i)) * (y_hat - test_data.latent_outcome(i)); + double y_hat = mu_hat + y_bar + cate_hat * test_data.z[i]; + y_rmse_sum += (y_hat - test_data.latent_outcome[i]) * (y_hat - test_data.latent_outcome[i]); } std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" From f91b765dac9312e2d62b8ff357a79f42ccef0d20 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 18:45:13 -0400 Subject: [PATCH 20/64] Added initial BART data structures --- include/stochtree/bart.h | 113 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 include/stochtree/bart.h diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h new file mode 100644 index 00000000..f86eb5ce --- /dev/null +++ b/include/stochtree/bart.h @@ -0,0 +1,113 @@ +/*! + * Copyright (c) 2026 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifndef STOCHTREE_BART_H_ +#define STOCHTREE_BART_H_ + +#include +#include +#include "stochtree/container.h" + +namespace StochTree { + +struct BARTData { + // Train set covariates + const double* X_train; + int n_train = 0; + int p = 0; + + // Test set covariates + const double* X_test; + int n_test = 0; + + // Train set outcome + const double* y_train; + + // Basis for leaf regression + const double* basis_train; + const double* basis_test; + int basis_dim = 0; + + // Observation weights + const double* obs_weights_train; + const double* obs_weights_test; + + // Random effects + const int* rfx_group_ids_train; + const int* rfx_group_ids_test; + const double* rfx_basis_train; + const double* rfx_basis_test; + int rfx_num_groups = 0; + int rfx_basis_dim = 0; + + // Feature types encoded as integers (e.g. 0 = continuous, 1 = categorical, etc.) + const int* feature_types; +}; + +struct BARTConfig { + // High level parameters + bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after + int num_threads = 1; // number of threads to use for sampling + + // Global error variance parameters + double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance + double b_sigma2_global = 0.0; // scale parameter for inverse gamma prior on global error variance + double sigma2_global_init = 1.0; // initial value for global error variance + + // Mean forest parameters + int num_trees_mean = 200; // number of trees in the mean forest + double alpha_mean = 0.95; // alpha parameter for mean forest tree prior + double beta_mean = 2.0; // beta parameter for mean forest tree prior + int min_samples_leaf_mean = 5; // minimum number of samples per leaf for mean forest + bool leaf_constant_mean = true; // whether to use constant leaf model for mean forest + bool exponentiated_leaf_mean = false; // whether to exponentiate leaf predictions for mean forest + int num_features_subsample_mean 
= 0; // number of features to subsample for each mean forest split (0 means no subsampling) + double a_sigma2_mean = 3.0; // shape parameter for inverse gamma prior on mean forest leaf scale + double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) + + // Variance forest parameters + int num_trees_variance = 0; // number of trees in the variance forest + double alpha_variance = 0.5; // alpha parameter for variance forest tree prior + double beta_variance = 2.0; // beta parameter for variance forest tree prior + int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest + bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest + bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest + int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) + + // TODO: Random effects parameters ... + + // TODO: Other parameters ... 
+}; + +struct BARTSamples { + // Posterior samples of training set mean forest predictions (num_samples x n_train, stored column-major) + std::vector mean_forest_predictions_train; + + // Posterior samples of training set variance forest predictions (num_samples x n_train, stored column-major) + std::vector variance_forest_predictions_train; + + // Posterior samples of test set mean forest predictions (num_samples x n_test, stored column-major) + std::vector mean_forest_predictions_test; + + // Posterior samples of test set variance forest predictions (num_samples x n_test, stored column-major) + std::vector variance_forest_predictions_test; + + // Posterior samples of global error variance (num_samples) + std::vector global_error_variance_samples; + + // Posterior samples of leaf scale (num_samples) + std::vector leaf_scale_samples; + + // Pointer to sampled mean forests + std::unique_ptr mean_forests; + + // Pointer to sampled variance forests + std::unique_ptr variance_forests; + + // TODO: Pointer to random effects samples ... +}; + +} // namespace StochTree + +#endif // STOCHTREE_BART_H_ From 7321b457fa4136c56674d0df9a5f15e5d3b2a55d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 9 Apr 2026 19:13:47 -0400 Subject: [PATCH 21/64] Updated debug program to use new BART data structures --- debug/bart_debug.cpp | 93 +++++++++++++++++++++++++++++++--------- include/stochtree/bart.h | 51 ++++++++++++---------- 2 files changed, 100 insertions(+), 44 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 9a62a012..0d7bd76a 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -15,6 +15,7 @@ * random effects, multivariate leaf, etc.). 
*/ +#include #include #include #include @@ -26,6 +27,7 @@ #include #include +#include #include #include #include @@ -99,25 +101,31 @@ using PostIterFn = std::function; using ReportFn = std::function&, double)>; void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, + StochTree::BARTConfig& config, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, StochTree::ForestDataset& test_dataset, PostIterFn post_iter, ReportFn report_results) { + // Initialize sample outputs + StochTree::BARTSamples bart_samples; + // Single-threaded with default cutpoint grid size (for now) - constexpr int num_threads = 1; - constexpr int cutpoint_grid_size = 100; + int num_threads = config.num_threads; + int cutpoint_grid_size = config.cutpoint_grid_size; // Model parameters for split rule selection and tree sweeps - std::vector feature_types(p, StochTree::FeatureType::kNumeric); - std::vector var_weights(p, 1.0 / p); - std::vector sweep_indices(num_trees); - std::iota(sweep_indices.begin(), sweep_indices.end(), 0); + std::vector feature_types(p); + for (int i = 0; i < p; i++) { + feature_types[i] = static_cast(config.feature_types[i]); + } + std::vector var_weights = config.var_weights_mean; + std::vector sweep_indices = config.sweep_update_indices; // Ephemeral sampler state - StochTree::TreePrior tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); - StochTree::ForestContainer forest_samples(num_trees, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); - StochTree::TreeEnsemble active_forest(num_trees, 1, true, false); - StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, num_trees, n); + StochTree::TreePrior tree_prior(config.alpha_mean, config.beta_mean, /*min_samples_leaf=*/config.min_samples_leaf_mean); + bart_samples.mean_forests = std::make_unique(config.num_trees_mean, /*output_dim=*/config.leaf_dim_mean, /*leaf_constant=*/config.leaf_constant_mean, 
/*exponentiated=*/config.exponentiated_leaf_mean); + StochTree::TreeEnsemble active_forest(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); + StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, config.num_trees_mean, n); // Initialize forest and tracker predictions to 0 (after standardization, this is the best initial guess) active_forest.SetLeafValue(0.0); @@ -125,8 +133,17 @@ void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int tracker.UpdatePredictions(&active_forest, dataset); // Initialize leaf model and global variance for sampling iterations - StochTree::GaussianConstantLeafModel leaf_model(1.0 / num_trees); - double global_variance = 1.0; + if (config.sigma2_mean_init < 0.0) { + // Data-informed initialization of leaf scale based on variance of the outcome and number of trees, following Chipman et al. (2010) + double y_var = 0.0; + for (int i = 0; i < n; i++) { + y_var += residual.GetData()[i] * residual.GetData()[i]; + } + y_var /= n; + config.sigma2_mean_init = y_var / config.num_trees_mean; + } + StochTree::GaussianConstantLeafModel leaf_model(config.sigma2_mean_init); + double global_variance = config.sigma2_global_init; // Run GFR std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; @@ -136,7 +153,7 @@ void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int StochTree::GFRSampleOneIter< StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>( - active_forest, tracker, forest_samples, leaf_model, + active_forest, tracker, *bart_samples.mean_forests, leaf_model, dataset, residual, tree_prior, rng, var_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, @@ -153,7 +170,7 @@ void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int StochTree::MCMCSampleOneIter< StochTree::GaussianConstantLeafModel, 
StochTree::GaussianConstantSuffStat>( - active_forest, tracker, forest_samples, leaf_model, + active_forest, tracker, *bart_samples.mean_forests, leaf_model, dataset, residual, tree_prior, rng, var_weights, sweep_indices, global_variance, /*keep_forest=*/true, /*pre_initialized=*/true, @@ -164,7 +181,7 @@ void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int } // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) - report_results(forest_samples.Predict(test_dataset), global_variance); + report_results(bart_samples.mean_forests->Predict(test_dataset), global_variance); } // ---- Scenario 0: homoskedastic constant-leaf BART ------------------- @@ -193,10 +210,17 @@ void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int nu resid_vec[i] = (data.y[i] - y_bar) / y_std; } + // Load data into BARTData object + StochTree::BARTData bart_data; + bart_data.n_train = n; + bart_data.p = p; + bart_data.X_train = data.X.data(); + bart_data.y_train = resid_vec.data(); + // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); - StochTree::ColumnVector residual(resid_vec.data(), n); + dataset.AddCovariates(bart_data.X_train, n, p, /*row_major=*/false); + StochTree::ColumnVector residual(bart_data.y_train, n); // Initialize global error variance model constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior @@ -228,8 +252,19 @@ void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int nu << " sigma (truth): 1.0\n"; }; + // Initialize BART config (same for GFR warmup and MCMC sampling) + StochTree::BARTConfig config; + config.num_trees_mean = num_trees; + config.a_sigma2_mean = a_sigma; + config.b_sigma2_mean = b_sigma; + config.cutpoint_grid_size = 100; + config.sweep_update_indices.resize(num_trees); + std::iota(config.sweep_update_indices.begin(), 
config.sweep_update_indices.end(), 0); + config.feature_types = std::vector(p, 0); + config.var_weights_mean = std::vector(p, 1.0 / p); + // Dispatch BART sampler - run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, test_dataset, post_iter, report); + run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, config, dataset, residual, rng, test_dataset, post_iter, report); } // ---- Scenario 1: constant-leaf probit BART ------------------- @@ -254,10 +289,17 @@ void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int nu Z_vec[i] = data.y[i] - y_bar; } + // Load data into BARTData object + StochTree::BARTData bart_data; + bart_data.n_train = n; + bart_data.p = p; + bart_data.X_train = data.X.data(); + bart_data.y_train = y_vec.data(); + // Initialize dataset and residual vector for sampler StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); - StochTree::ColumnVector residual(Z_vec.data(), n); + dataset.AddCovariates(bart_data.X_train, n, p, /*row_major=*/false); + StochTree::ColumnVector residual(bart_data.y_train, n); // Lambda function for probit data augmentation sampling step (after each forest sample) auto post_iter = [&](StochTree::ForestTracker& tracker, double&) { @@ -285,8 +327,17 @@ void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int nu << " sigma (truth): 1.0\n"; }; + // Initialize BART config (same for GFR warmup and MCMC sampling) + StochTree::BARTConfig config; + config.num_trees_mean = num_trees; + config.cutpoint_grid_size = 100; + config.sweep_update_indices.resize(num_trees); + std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); + config.feature_types = std::vector(p, 0); + config.var_weights_mean = std::vector(p, 1.0 / p); + // Dispatch BART sampler - run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, dataset, residual, rng, test_dataset, post_iter, report); + 
run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, config, dataset, residual, rng, test_dataset, post_iter, report); } // ---- Main ----------------------------------------------------------- diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index f86eb5ce..c640b8a9 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -13,42 +13,42 @@ namespace StochTree { struct BARTData { // Train set covariates - const double* X_train; + double* X_train; int n_train = 0; int p = 0; // Test set covariates - const double* X_test; + double* X_test; int n_test = 0; // Train set outcome - const double* y_train; + double* y_train; // Basis for leaf regression - const double* basis_train; - const double* basis_test; + double* basis_train; + double* basis_test; int basis_dim = 0; // Observation weights - const double* obs_weights_train; - const double* obs_weights_test; + double* obs_weights_train; + double* obs_weights_test; // Random effects - const int* rfx_group_ids_train; - const int* rfx_group_ids_test; - const double* rfx_basis_train; - const double* rfx_basis_test; + int* rfx_group_ids_train; + int* rfx_group_ids_test; + double* rfx_basis_train; + double* rfx_basis_test; int rfx_num_groups = 0; int rfx_basis_dim = 0; - - // Feature types encoded as integers (e.g. 0 = continuous, 1 = categorical, etc.) 
- const int* feature_types; }; struct BARTConfig { // High level parameters - bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after - int num_threads = 1; // number of threads to use for sampling + bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after + int num_threads = 1; // number of threads to use for sampling + int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits + std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical + std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) // Global error variance parameters double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance @@ -61,19 +61,24 @@ struct BARTConfig { double beta_mean = 2.0; // beta parameter for mean forest tree prior int min_samples_leaf_mean = 5; // minimum number of samples per leaf for mean forest bool leaf_constant_mean = true; // whether to use constant leaf model for mean forest + int leaf_dim_mean = 1; // dimension of the leaf for mean forest bool exponentiated_leaf_mean = false; // whether to exponentiate leaf predictions for mean forest int num_features_subsample_mean = 0; // number of features to subsample for each mean forest split (0 means no subsampling) double a_sigma2_mean = 3.0; // shape parameter for inverse gamma prior on mean forest leaf scale double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) + double sigma2_mean_init = -1.0; // initial value of mean forest leaf scale (-1 is a sentinel value that triggers a 
data-informed calibration based on the variance of the outcome and the number of trees) + std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) // Variance forest parameters - int num_trees_variance = 0; // number of trees in the variance forest - double alpha_variance = 0.5; // alpha parameter for variance forest tree prior - double beta_variance = 2.0; // beta parameter for variance forest tree prior - int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest - bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest - bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest - int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) + int num_trees_variance = 0; // number of trees in the variance forest + double alpha_variance = 0.5; // alpha parameter for variance forest tree prior + double beta_variance = 2.0; // beta parameter for variance forest tree prior + int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest + bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest + int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) + bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest + int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) + std::vector var_weights_variance; // variable weights for variance forest splits (should be same length as number of covariates in the dataset) // TODO: Random effects parameters ... 
From b3cced0ae9d6e4cb98824b19b520f8337191738d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 09:21:52 -0400 Subject: [PATCH 22/64] Initial (non-working) implementation of BARTSampler --- .vscode/launch.json | 8 +- CMakeLists.txt | 19 +- debug/bart_debug.cpp | 334 +++++--------------- debug/bcf_debug.cpp | 503 +++++++++---------------------- include/stochtree/bart.h | 52 ++-- include/stochtree/bart_sampler.h | 80 +++++ src/bart_sampler.cpp | 181 +++++++++++ 7 files changed, 538 insertions(+), 639 deletions(-) create mode 100644 include/stochtree/bart_sampler.h create mode 100644 src/bart_sampler.cpp diff --git a/.vscode/launch.json b/.vscode/launch.json index 8ae6485b..fa7e941c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -33,7 +33,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -63,7 +63,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", 
"--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -72,7 +72,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" diff --git a/CMakeLists.txt b/CMakeLists.txt index 753c1975..cbf8b53c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,6 +132,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build) file( GLOB SOURCES + src/bart_sampler.cpp src/container.cpp src/cutpoint_candidates.cpp src/data.cpp @@ -224,14 +225,14 @@ if(BUILD_DEBUG_TARGETS) target_link_libraries(bart_debug PRIVATE stochtree_objs) endif() - # BCF debug driver - add_executable(bcf_debug debug/bcf_debug.cpp) - if(USE_OPENMP) - target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR}) - target_link_libraries(bcf_debug PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY}) - else() - target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR}) - target_link_libraries(bcf_debug PRIVATE stochtree_objs) - endif() + # BCF debug driver (temporarily disabled) + # add_executable(bcf_debug debug/bcf_debug.cpp) + # if(USE_OPENMP) + # 
target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR}) + # target_link_libraries(bcf_debug PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY}) + # else() + # target_include_directories(bcf_debug PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR}) + # target_link_libraries(bcf_debug PRIVATE stochtree_objs) + # endif() endif() diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 0d7bd76a..2251fda4 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -15,24 +15,16 @@ * random effects, multivariate leaf, etc.). */ +// TODO: Replace with #include once Task 1.3 is complete. #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include #include #include #include #include +#include "stochtree/meta.h" static constexpr double kPi = 3.14159265358979323846; @@ -49,8 +41,8 @@ struct ProbitDataset { std::vector Z; }; -// DGP: y ~ sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) -RegressionDataset generate_constant_leaf_regression_data(int n, int p, std::mt19937& rng) { +// DGP: y = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) +static RegressionDataset generate_regression_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); RegressionDataset d; @@ -64,11 +56,10 @@ RegressionDataset generate_constant_leaf_regression_data(int n, int p, std::mt19 return d; } -// DGP -// --- -// Z ~ sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) -// y = 1{Z > 0} -ProbitDataset generate_probit_data(int n, int p, std::mt19937& rng) { +// DGP: +// Z = sin(2*pi*x1) + 0.5*x2 - 1.5*x3 + N(0,1) +// y = 1{Z > 0} +static ProbitDataset generate_probit_data(int n, int p, std::mt19937& rng) { 
std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); ProbitDataset d; @@ -85,265 +76,104 @@ ProbitDataset generate_probit_data(int n, int p, std::mt19937& rng) { return d; } -// ---- Shared sampler loop -------------------------------------------- +// ---- Reporter -------------------------------------------------------- // -// Runs GFR warmup then MCMC sampling, both using the same forest/leaf/variance -// setup. The two scenario-specific hooks are: -// -// post_iter(tracker, global_variance) — called after every forest sample in -// both GFR and MCMC (e.g. sample global variance, or augment latent Z). -// -// report_results(preds, global_variance) — called once after all samples are -// collected; receives the flat column-major predictions matrix and the -// final global variance value. - -using PostIterFn = std::function; -using ReportFn = std::function&, double)>; - -void run_bart_sampler(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, - StochTree::BARTConfig& config, - StochTree::ForestDataset& dataset, - StochTree::ColumnVector& residual, std::mt19937& rng, - StochTree::ForestDataset& test_dataset, - PostIterFn post_iter, ReportFn report_results) { - // Initialize sample outputs - StochTree::BARTSamples bart_samples; - - // Single-threaded with default cutpoint grid size (for now) - int num_threads = config.num_threads; - int cutpoint_grid_size = config.cutpoint_grid_size; - - // Model parameters for split rule selection and tree sweeps - std::vector feature_types(p); - for (int i = 0; i < p; i++) { - feature_types[i] = static_cast(config.feature_types[i]); +// Reads directly from BARTSamples (already un-standardized by BARTSamplerFit). +// test_ref is the comparison target on the original outcome scale. 
+ +static void report_bart(const StochTree::BARTSamples& samples, + const std::vector& test_ref, + const char* scenario_name) { + const int num_samples = samples.num_samples; + const int n_test = samples.num_test; + double rmse_sum = 0.0; + for (int i = 0; i < n_test; i++) { + double mu_hat = 0.0; + for (int j = 0; j < num_samples; j++) + mu_hat += samples.mean_forest_predictions_test[static_cast(j * n_test + i)] / num_samples; + double err = mu_hat - test_ref[i]; + rmse_sum += err * err; } - std::vector var_weights = config.var_weights_mean; - std::vector sweep_indices = config.sweep_update_indices; - - // Ephemeral sampler state - StochTree::TreePrior tree_prior(config.alpha_mean, config.beta_mean, /*min_samples_leaf=*/config.min_samples_leaf_mean); - bart_samples.mean_forests = std::make_unique(config.num_trees_mean, /*output_dim=*/config.leaf_dim_mean, /*leaf_constant=*/config.leaf_constant_mean, /*exponentiated=*/config.exponentiated_leaf_mean); - StochTree::TreeEnsemble active_forest(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); - StochTree::ForestTracker tracker(dataset.GetCovariates(), feature_types, config.num_trees_mean, n); - - // Initialize forest and tracker predictions to 0 (after standardization, this is the best initial guess) - active_forest.SetLeafValue(0.0); - UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, false, std::minus()); - tracker.UpdatePredictions(&active_forest, dataset); - - // Initialize leaf model and global variance for sampling iterations - if (config.sigma2_mean_init < 0.0) { - // Data-informed initialization of leaf scale based on variance of the outcome and number of trees, following Chipman et al. 
(2010) - double y_var = 0.0; - for (int i = 0; i < n; i++) { - y_var += residual.GetData()[i] * residual.GetData()[i]; - } - y_var /= n; - config.sigma2_mean_init = y_var / config.num_trees_mean; + std::cout << "\n" + << scenario_name << ":\n" + << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n"; + if (!samples.global_error_variance_samples.empty()) { + std::cout << " sigma (last): " << std::sqrt(samples.global_error_variance_samples.back()) << "\n" + << " sigma (truth): 1.0\n"; } - StochTree::GaussianConstantLeafModel leaf_model(config.sigma2_mean_init); - double global_variance = config.sigma2_global_init; - - // Run GFR - std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; - bool pre_initialized = true; - for (int i = 0; i < num_gfr; i++) { - // Sample forest - StochTree::GFRSampleOneIter< - StochTree::GaussianConstantLeafModel, - StochTree::GaussianConstantSuffStat>( - active_forest, tracker, *bart_samples.mean_forests, leaf_model, - dataset, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, feature_types, - cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, - /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); - - // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) - post_iter(tracker, global_variance); - } - - // Run MCMC - std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; - for (int i = 0; i < num_mcmc; i++) { - // Sample forest - StochTree::MCMCSampleOneIter< - StochTree::GaussianConstantLeafModel, - StochTree::GaussianConstantSuffStat>( - active_forest, tracker, *bart_samples.mean_forests, leaf_model, - dataset, residual, tree_prior, rng, - var_weights, sweep_indices, global_variance, - /*keep_forest=*/true, /*pre_initialized=*/true, - /*backfitting=*/true, num_threads); - - // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) 
- post_iter(tracker, global_variance); - } - - // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) - report_results(bart_samples.mean_forests->Predict(test_dataset), global_variance); } -// ---- Scenario 0: homoskedastic constant-leaf BART ------------------- +// ---- Scenario 0: homoskedastic constant-leaf BART -------------------- -void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { - // Allow seed to be non-deterministic if set to sentinel value of -1 - int rng_seed; - if (seed == -1) { - std::random_device rd; - rng_seed = rd(); - } else { - rng_seed = seed; - } - std::mt19937 rng(rng_seed); +static void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed) { + std::mt19937 rng(seed < 0 ? std::random_device{}() : static_cast(seed)); + RegressionDataset train = generate_regression_data(n, p, rng); + RegressionDataset test = generate_regression_data(n_test, p, rng); - // Generate data - RegressionDataset data = generate_constant_leaf_regression_data(n, p, rng); - double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); - double y_std = 0; - for (int i = 0; i < n; i++) { - y_std += (data.y[i] - y_bar) * (data.y[i] - y_bar); - } - y_std = std::sqrt(y_std / n); - std::vector resid_vec(data.y.size()); - for (std::size_t i = 0; i < data.y.size(); i++) { - resid_vec[i] = (data.y[i] - y_bar) / y_std; - } + StochTree::BARTData data; + data.X_train = train.X.data(); + data.y_train = train.y.data(); + data.n_train = n; + data.p = p; + data.X_test = test.X.data(); + data.n_test = n_test; - // Load data into BARTData object - StochTree::BARTData bart_data; - bart_data.n_train = n; - bart_data.p = p; - bart_data.X_train = data.X.data(); - bart_data.y_train = resid_vec.data(); - - // Initialize dataset and residual vector for sampler - StochTree::ForestDataset dataset; - dataset.AddCovariates(bart_data.X_train, n, p, 
/*row_major=*/false); - StochTree::ColumnVector residual(bart_data.y_train, n); - - // Initialize global error variance model - constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior - StochTree::GlobalHomoskedasticVarianceModel var_model; - - // Lambda function for sampling global error variance after each forest sample - auto post_iter = [&](StochTree::ForestTracker&, double& global_variance) { - global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); - }; - - // Generate test data and build test dataset - RegressionDataset test_data = generate_constant_leaf_regression_data(n_test, p, rng); - StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); - - // Lambda function for reporting test-set RMSE and last draw of global error variance model - auto report = [&](const std::vector& preds, double global_variance) { - double rmse_sum = 0.0; - for (int i = 0; i < n_test; i++) { - double mu_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; - double err = (mu_hat * y_std + y_bar) - test_data.y[i]; - rmse_sum += err * err; - } - std::cout << "\nScenario 0 (Homoskedastic BART):\n" - << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n" - << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" - << " sigma (truth): 1.0\n"; - }; - - // Initialize BART config (same for GFR warmup and MCMC sampling) StochTree::BARTConfig config; config.num_trees_mean = num_trees; - config.a_sigma2_mean = a_sigma; - config.b_sigma2_mean = b_sigma; - config.cutpoint_grid_size = 100; - config.sweep_update_indices.resize(num_trees); - std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); - config.feature_types = std::vector(p, 0); + config.random_seed = seed; + config.probit = false; + config.standardize_outcome = true; + config.sample_sigma2_global = false; 
config.var_weights_mean = std::vector(p, 1.0 / p); + config.feature_types = std::vector(p, StochTree::FeatureType::kNumeric); + config.sweep_update_indices = std::vector(num_trees, 0); + std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); - // Dispatch BART sampler - run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, config, dataset, residual, rng, test_dataset, post_iter, report); + StochTree::BARTSamples samples; + StochTree::BARTSampler sampler(samples, config, data); + sampler.run_gfr(samples, config, data, rng, num_gfr, true); + sampler.run_mcmc(samples, config, data, rng, 0, 1, num_mcmc); + report_bart(samples, test.y, "Scenario 0 (Homoskedastic BART)"); } -// ---- Scenario 1: constant-leaf probit BART ------------------- - -void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed = 1234) { - // Allow seed to be non-deterministic if set to sentinel value of -1 - int rng_seed; - if (seed == -1) { - std::random_device rd; - rng_seed = rd(); - } else { - rng_seed = seed; - } - std::mt19937 rng(rng_seed); - - // Generate data - ProbitDataset data = generate_probit_data(n, p, rng); - double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); - std::vector y_vec = data.y; - std::vector Z_vec(n); - for (int i = 0; i < n; i++) { - Z_vec[i] = data.y[i] - y_bar; - } - - // Load data into BARTData object - StochTree::BARTData bart_data; - bart_data.n_train = n; - bart_data.p = p; - bart_data.X_train = data.X.data(); - bart_data.y_train = y_vec.data(); +// ---- Scenario 1: constant-leaf probit BART --------------------------- - // Initialize dataset and residual vector for sampler - StochTree::ForestDataset dataset; - dataset.AddCovariates(bart_data.X_train, n, p, /*row_major=*/false); - StochTree::ColumnVector residual(bart_data.y_train, n); +static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, int num_mcmc, int seed) { + std::mt19937 
rng(seed < 0 ? std::random_device{}() : static_cast(seed)); + ProbitDataset train = generate_probit_data(n, p, rng); + ProbitDataset test = generate_probit_data(n_test, p, rng); - // Lambda function for probit data augmentation sampling step (after each forest sample) - auto post_iter = [&](StochTree::ForestTracker& tracker, double&) { - StochTree::sample_probit_latent_outcome( - rng, y_vec.data(), tracker.GetSumPredictions(), residual.GetData().data(), y_bar, n); - }; + StochTree::BARTData data; + data.X_train = train.X.data(); + data.y_train = train.y.data(); + data.n_train = n; + data.p = p; + data.X_test = test.X.data(); + data.n_test = n_test; - // Generate test data and build test dataset - ProbitDataset test_data = generate_probit_data(n_test, p, rng); - StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); - - // Lambda function for reporting test-set RMSE - auto report = [&](const std::vector& preds, double global_variance) { - double rmse_sum = 0.0; - for (int i = 0; i < n_test; i++) { - double mu_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - mu_hat += preds[static_cast(j * n_test + i)] / num_mcmc; - double err = (mu_hat + y_bar) - test_data.Z[i]; - rmse_sum += err * err; - } - std::cout << "\nScenario 1 (Probit BART):\n" - << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n" - << " sigma (truth): 1.0\n"; - }; - - // Initialize BART config (same for GFR warmup and MCMC sampling) StochTree::BARTConfig config; config.num_trees_mean = num_trees; - config.cutpoint_grid_size = 100; - config.sweep_update_indices.resize(num_trees); - std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); - config.feature_types = std::vector(p, 0); + config.random_seed = seed; + config.probit = true; + config.standardize_outcome = true; + config.sample_sigma2_global = false; config.var_weights_mean = std::vector(p, 1.0 / p); + config.feature_types = std::vector(p, 
StochTree::FeatureType::kNumeric); + config.sweep_update_indices = std::vector(num_trees, 0); + std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); - // Dispatch BART sampler - run_bart_sampler(n, n_test, p, num_trees, num_gfr, num_mcmc, config, dataset, residual, rng, test_dataset, post_iter, report); + StochTree::BARTSamples samples; + StochTree::BARTSampler sampler(samples, config, data); + sampler.run_gfr(samples, config, data, rng, num_gfr, true); + sampler.run_mcmc(samples, config, data, rng, 0, 1, num_mcmc); + // Predictions are on latent scale (= raw + y_bar); compare to true latent Z. + report_bart(samples, test.Z, "Scenario 1 (Probit BART)"); } // ---- Main ----------------------------------------------------------- int main(int argc, char** argv) { - int scenario = 1; + int scenario = 0; int n = 500; int n_test = 100; int p = 5; @@ -391,7 +221,7 @@ int main(int argc, char** argv) { break; default: std::cerr << "Unknown scenario " << scenario - << ". Available scenarios: 0 (Homoskedastic BART), 1 (Probit BART)\n"; + << ". Available: 0 (Homoskedastic BART), 1 (Probit BART)\n"; return 1; } return 0; diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index ff91552c..ec68e4db 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -1,38 +1,35 @@ /* - * BCF debug program + * BCF debug program. The first CLI argument selects the scenario (default: 0). 
* - * Usage: bcf_debug [--scenario N] [--n N] [--n_test N] [--p N] [--num_trees_mu N] [--num_trees_tau N] + * Usage: bcf_debug [--scenario N] [--n N] [--n_test N] [--p N] + * [--num_trees_mu N] [--num_trees_tau N] * [--num_gfr N] [--num_mcmc N] [--seed N] * * 0 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) - * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 + * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 * tau(x) = 1 + x3 * z ~ Bernoulli(0.5) * y = mu(x) + tau(x)*z + N(0, 0.5^2) * - * 1 Two-forest BCF: constant-leaf mu, univariate-leaf tau (Z as basis) - * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 + * 1 Two-forest BCF with probit link: constant-leaf mu, univariate-leaf tau + * DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2 * tau(x) = 1 + x3 * z ~ Bernoulli(0.5) * W = mu(x) + tau(x)*z + N(0, 1) * y = 1{W > 0} * - * Add scenarios here as the BCFSampler API develops (heteroskedastic, - * random effects, propensity weighting, etc.). + * Add scenarios here as the BCFSampler API develops (propensity covariate, + * adaptive coding, random effects, etc.). + * + * TODO: Replace the stub include with once + * Task 2.2 (src/bcf_sampler.cpp) is complete. */ -#include -#include -#include -#include -#include -#include -#include -#include +// TODO: Replace with #include once Task 2.2 is complete. 
+#include "bcf_sampler_stub.h" #include #include -#include #include #include #include @@ -58,36 +55,35 @@ struct ProbitBCFDataset { std::vector tau_true; }; -SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { +// DGP: mu(x) = 2*sin(pi*x1) + 0.5*x2, tau(x) = 1 + x3 +// z ~ Bernoulli(0.5), y = mu + tau*z + N(0, 0.25) +static SimpleBCFDataset generate_simple_bcf_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); std::bernoulli_distribution bern(0.5); - SimpleBCFDataset d; d.X.resize(n * p); d.y.resize(n); d.z.resize(n); d.mu_true.resize(n); d.tau_true.resize(n); - for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) d.X[j * n + i] = unif(rng); - for (int i = 0; i < n; i++) { - d.z[i] = bern(rng) ? 1.0 : 0.0; - d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; + d.z[i] = bern(rng) ? 1.0 : 0.0; + d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; d.tau_true[i] = 1.0 + d.X[2 * n + i]; - d.y[i] = d.mu_true[i] + d.tau_true[i] * d.z[i] + 0.5 * normal(rng); + d.y[i] = d.mu_true[i] + d.tau_true[i] * d.z[i] + 0.5 * normal(rng); } return d; } -ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { +// DGP: same mu/tau; W = mu + tau*z + N(0,1); y = 1{W > 0} +static ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { std::uniform_real_distribution unif(0.0, 1.0); std::normal_distribution normal(0.0, 1.0); std::bernoulli_distribution bern(0.5); - ProbitBCFDataset d; d.X.resize(n * p); d.y.resize(n); @@ -95,357 +91,158 @@ ProbitBCFDataset generate_probit_bcf_data(int n, int p, std::mt19937& rng) { d.mu_true.resize(n); d.tau_true.resize(n); d.latent_outcome.resize(n); - for (int i = 0; i < n; i++) for (int j = 0; j < p; j++) d.X[j * n + i] = unif(rng); - for (int i = 0; i < n; i++) { - d.z[i] = bern(rng) ? 
1.0 : 0.0; - d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; - d.tau_true[i] = 1.0 + d.X[2 * n + i]; + d.z[i] = bern(rng) ? 1.0 : 0.0; + d.mu_true[i] = 2.0 * std::sin(kPi * d.X[i]) + 0.5 * d.X[1 * n + i]; + d.tau_true[i] = 1.0 + d.X[2 * n + i]; d.latent_outcome[i] = d.mu_true[i] + d.tau_true[i] * d.z[i] + normal(rng); - d.y[i] = (d.latent_outcome[i] > 0.0) ? 1.0 : 0.0; + d.y[i] = (d.latent_outcome[i] > 0.0) ? 1.0 : 0.0; } return d; } -// ---- Shared sampler loop -------------------------------------------- -// -// Runs alternating mu/tau GFR warmup then MCMC, sharing a single residual. -// The two scenario-specific hooks are: +// ---- Reporter -------------------------------------------------------- // -// post_iter(mu_tracker, global_variance) — called after each full mu+tau -// iteration (e.g. sample global variance). -// -// report_results(mu_preds, tau_preds, global_variance) — called once after -// all samples are collected; receives column-major prediction matrices -// and the final global variance value. 
- -using PostIterFn = std::function; -using BCFReportFn = std::function&, const std::vector&, double)>; - -void run_bcf_sampler(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, - StochTree::ForestDataset& dataset, - StochTree::ColumnVector& residual, std::mt19937& rng, - StochTree::ForestDataset& test_dataset, - PostIterFn post_iter, BCFReportFn report_results) { - // Single-threaded with default cutpoint grid size (for now) - constexpr int num_threads = 1; - constexpr int cutpoint_grid_size = 100; - - // Model parameters for split rule selection and tree sweeps - std::vector feature_types(p, StochTree::FeatureType::kNumeric); - std::vector var_weights(p, 1.0 / p); - std::vector sweep_indices_mu(num_trees_mu); - std::iota(sweep_indices_mu.begin(), sweep_indices_mu.end(), 0); - std::vector sweep_indices_tau(num_trees_tau); - std::iota(sweep_indices_tau.begin(), sweep_indices_tau.end(), 0); - - // Ephemeral sampler state - // Mu forest: constant-leaf - StochTree::TreePrior mu_tree_prior(0.95, 2.0, /*min_samples_leaf=*/5); - StochTree::ForestContainer mu_samples(num_trees_mu, /*output_dim=*/1, /*leaf_constant=*/true, /*exponentiated=*/false); - StochTree::TreeEnsemble mu_forest(num_trees_mu, 1, true, false); - StochTree::ForestTracker mu_tracker(dataset.GetCovariates(), feature_types, num_trees_mu, n); - StochTree::GaussianConstantLeafModel mu_leaf_model(1.0 / num_trees_mu); - - // Tau forest: univariate regression leaf (prediction = leaf_param * z) - StochTree::TreePrior tau_tree_prior(0.5, 2.0, /*min_samples_leaf=*/5); - StochTree::ForestContainer tau_samples(num_trees_tau, /*output_dim=*/1, /*leaf_constant=*/false, /*exponentiated=*/false); - StochTree::TreeEnsemble tau_forest(num_trees_tau, 1, false, false); - StochTree::ForestTracker tau_tracker(dataset.GetCovariates(), feature_types, num_trees_tau, n); - StochTree::GaussianUnivariateRegressionLeafModel tau_leaf_model(1.0 / num_trees_tau); - - // Initialize mu forest and 
tracker predictions to 0 - mu_forest.SetLeafValue(0.0); - UpdateResidualEntireForest(mu_tracker, dataset, residual, &mu_forest, false, std::minus()); - mu_tracker.UpdatePredictions(&mu_forest, dataset); - - // Initial tau forest and tracker predictions to 0 - tau_forest.SetLeafValue(0.0); - UpdateResidualEntireForest(tau_tracker, dataset, residual, &tau_forest, false, std::minus()); - tau_tracker.UpdatePredictions(&tau_forest, dataset); - - // Model predictions - std::vector outcome_preds(n, 0.0); - - // Initialize global error variance to 1 (output is standardized) - double global_variance = 1.0; - - // Run GFR - std::cout << "[GFR] " << num_gfr << " warmup iterations...\n"; - bool pre_initialized = true; - for (int i = 0; i < num_gfr; i++) { - // Sample mu forest - StochTree::GFRSampleOneIter< - StochTree::GaussianConstantLeafModel, - StochTree::GaussianConstantSuffStat>( - mu_forest, mu_tracker, mu_samples, mu_leaf_model, - dataset, residual, mu_tree_prior, rng, - var_weights, sweep_indices_mu, global_variance, feature_types, - cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, - /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); - - // Sample tau forest - StochTree::GFRSampleOneIter< - StochTree::GaussianUnivariateRegressionLeafModel, - StochTree::GaussianUnivariateRegressionSuffStat>( - tau_forest, tau_tracker, tau_samples, tau_leaf_model, - dataset, residual, tau_tree_prior, rng, - var_weights, sweep_indices_tau, global_variance, feature_types, - cutpoint_grid_size, /*keep_forest=*/false, pre_initialized, - /*backfitting=*/true, /*num_features_subsample=*/p, num_threads); - - // Update predictions and residual for post-iteration hook (e.g. global variance sampling, probit data augmentation, etc.) - for (int j = 0; j < n; j++) { - outcome_preds[j] = mu_tracker.GetSamplePrediction(j) + tau_tracker.GetSamplePrediction(j); +// Reads directly from BCFSamples (already un-standardized by BCFSamplerFit). 
+// mu_ref — true prognostic function (original outcome scale) +// tau_ref — true CATE (treatment effect scale, no y_bar offset) +// y_ref — true outcome or latent outcome for comparison + +static void report_bcf(const StochTree::BCFSamples& samples, + const std::vector& mu_ref, + const std::vector& tau_ref, + const std::vector& y_ref, + const char* scenario_name) { + const int num_samples = samples.num_samples; + const int n_test = samples.n_test; + double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; + for (int i = 0; i < n_test; i++) { + double mu_hat = 0.0, tau_hat = 0.0, y_hat = 0.0; + for (int j = 0; j < num_samples; j++) { + const auto k = static_cast(j * n_test + i); + mu_hat += samples.mu_hat_test[k] / num_samples; + tau_hat += samples.tau_hat_test[k] / num_samples; + y_hat += samples.y_hat_test[k] / num_samples; } - - // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) - post_iter(outcome_preds.data(), global_variance); + mu_rmse_sum += (mu_hat - mu_ref[i]) * (mu_hat - mu_ref[i]); + tau_rmse_sum += (tau_hat - tau_ref[i]) * (tau_hat - tau_ref[i]); + y_rmse_sum += (y_hat - y_ref[i]) * (y_hat - y_ref[i]); } - - // Run MCMC - std::cout << "[MCMC] " << num_mcmc << " sampling iterations...\n"; - for (int i = 0; i < num_mcmc; i++) { - // Sample mu forest - StochTree::MCMCSampleOneIter< - StochTree::GaussianConstantLeafModel, - StochTree::GaussianConstantSuffStat>( - mu_forest, mu_tracker, mu_samples, mu_leaf_model, - dataset, residual, mu_tree_prior, rng, - var_weights, sweep_indices_mu, global_variance, - /*keep_forest=*/true, /*pre_initialized=*/true, - /*backfitting=*/true, num_threads); - - // Sample tau forest - StochTree::MCMCSampleOneIter< - StochTree::GaussianUnivariateRegressionLeafModel, - StochTree::GaussianUnivariateRegressionSuffStat>( - tau_forest, tau_tracker, tau_samples, tau_leaf_model, - dataset, residual, tau_tree_prior, rng, - var_weights, sweep_indices_tau, global_variance, - 
/*keep_forest=*/true, /*pre_initialized=*/true, - /*backfitting=*/true, num_threads); - - // Update predictions and residual for post-iteration hook (e.g. global variance sampling, probit data augmentation, etc.) - for (int j = 0; j < n; j++) { - outcome_preds[j] = mu_tracker.GetSamplePrediction(j) + tau_tracker.GetSamplePrediction(j); - } - - // Sample other model parameters (e.g. global variance, probit data augmentation, etc.) - post_iter(outcome_preds.data(), global_variance); + std::cout << "\n" << scenario_name << ":\n" + << " mu RMSE (test): " << std::sqrt(mu_rmse_sum / n_test) << "\n" + << " tau RMSE (test): " << std::sqrt(tau_rmse_sum / n_test) << "\n" + << " y RMSE (test): " << std::sqrt(y_rmse_sum / n_test) << "\n"; + if (!samples.global_error_variance_samples.empty()) { + std::cout << " sigma (last): " + << std::sqrt(samples.global_error_variance_samples.back()) << "\n"; } - - // Analyze posterior predictions (column-major, element [j*n_test + i] = sample j, obs i) - report_results(mu_samples.Predict(test_dataset), tau_samples.PredictRaw(test_dataset), global_variance); } -// ---- Scenario 0: constant-leaf mu + univariate-leaf tau (Z basis) --- - -void run_scenario_0(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { - // Allow seed to be non-deterministic if set to sentinel value of -1 - int rng_seed; - if (seed == -1) { - std::random_device rd; - rng_seed = rd(); - } else { - rng_seed = seed; - } - std::mt19937 rng(rng_seed); - - // Generate data and standardize outcome - SimpleBCFDataset data = generate_simple_bcf_data(n, p, rng); - double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); - double y_std = 0; - for (int i = 0; i < n; i++) { - y_std += (data.y[i] - y_bar) * (data.y[i] - y_bar); - } - y_std = std::sqrt(y_std / n); - std::vector resid_vec(n); - for (int i = 0; i < n; i++) { - resid_vec[i] = (data.y[i] - y_bar) / y_std; - } - - // Shared dataset: only tau 
forest uses the Z basis for leaf regression - StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/false); - dataset.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); - - // Shared residual - StochTree::ColumnVector residual(resid_vec.data(), n); - - // Global error variance model - constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior - StochTree::GlobalHomoskedasticVarianceModel var_model; - - // Lambda function for sampling global error variance after each mu+tau step - auto post_iter = [&](double* outcome_preds, double& global_variance) { - global_variance = var_model.SampleVarianceParameter(residual.GetData(), a_sigma, b_sigma, rng); - }; - - // Generate test data and build test datasets - SimpleBCFDataset test_data = generate_simple_bcf_data(n_test, p, rng); - - // Test dataset: covariates + actual treatment z (for y prediction) - StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/false); - test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, /*row_major=*/false); - - // Lambda function for reporting mu/tau RMSE and last draw of global error variance - auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, - double global_variance) { - double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; - - for (int i = 0; i < n_test; i++) { - double mu_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; - mu_rmse_sum += (mu_hat * y_std + y_bar - test_data.mu_true[i]) * (mu_hat * y_std + y_bar - test_data.mu_true[i]); - - // tau_preds from test_dataset_cate (z=1 basis) => raw CATE estimates - double cate_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; - tau_rmse_sum += (cate_hat * y_std - test_data.tau_true[i]) * (cate_hat * y_std - test_data.tau_true[i]); - - double y_hat = 
mu_hat * y_std + y_bar + cate_hat * test_data.z[i] * y_std; - y_rmse_sum += (y_hat - test_data.y[i]) * (y_hat - test_data.y[i]); - } - - std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" - << " mu RMSE (test): " << std::sqrt(mu_rmse_sum / n_test) << "\n" - << " tau RMSE (test): " << std::sqrt(tau_rmse_sum / n_test) << "\n" - << " y RMSE (test): " << std::sqrt(y_rmse_sum / n_test) << "\n" - << " sigma (last sample): " << std::sqrt(global_variance) * y_std << "\n" - << " sigma (truth): 0.5\n"; - }; - - // Dispatch BCF sampler - run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, - dataset, residual, rng, test_dataset, post_iter, report); +// ---- Scenario 0: constant-leaf mu + univariate-leaf tau (identity link) --- + +static void run_scenario_0(int n, int n_test, int p, + int num_trees_mu, int num_trees_tau, + int num_gfr, int num_mcmc, int seed) { + std::mt19937 rng(seed < 0 ? std::random_device{}() : static_cast(seed)); + SimpleBCFDataset train = generate_simple_bcf_data(n, p, rng); + SimpleBCFDataset test = generate_simple_bcf_data(n_test, p, rng); + + StochTree::BCFData data; + data.X_train = train.X.data(); + data.y_train = train.y.data(); + data.z_train = train.z.data(); + data.n_train = n; + data.p_x = p; + data.treatment_dim = 1; + data.X_test = test.X.data(); + data.z_test = test.z.data(); + data.n_test = n_test; + + StochTree::BCFConfig config; + config.num_trees_mu = num_trees_mu; + config.num_trees_tau = num_trees_tau; + config.num_gfr = num_gfr; + config.num_mcmc = num_mcmc; + config.random_seed = seed; + config.link_function = StochTree::LinkFunction::Identity; + config.standardize_outcome = true; + config.sample_sigma2_global = true; + + StochTree::BCFSamples samples; + StochTree::BCFSamplerFit(&samples, config, data); + report_bcf(samples, test.mu_true, test.tau_true, test.y, + "Scenario 0 (BCF: constant mu + univariate tau, identity link)"); + std::cout << " sigma (truth): 0.5\n"; } -// ---- 
Scenario 1: constant-leaf mu + univariate-leaf tau (Z basis) with probit link --- - -void run_scenario_1(int n, int n_test, int p, int num_trees_mu, int num_trees_tau, int num_gfr, int num_mcmc, int seed = 42) { - // Allow seed to be non-deterministic if set to sentinel value of -1 - int rng_seed; - if (seed == -1) { - std::random_device rd; - rng_seed = rd(); - } else { - rng_seed = seed; - } - std::mt19937 rng(rng_seed); - - // Generate data and standardize outcome - ProbitBCFDataset data = generate_probit_bcf_data(n, p, rng); - double y_bar = std::accumulate(data.y.begin(), data.y.end(), 0.0) / data.y.size(); - std::vector Z_vec(n); - for (int i = 0; i < n; i++) { - Z_vec[i] = data.y[i] - y_bar; - } - std::vector y_vec = data.y; - - // Shared dataset: only tau forest uses the Z basis for leaf regression - StochTree::ForestDataset dataset; - dataset.AddCovariates(data.X.data(), n, p, /*row_major=*/true); - dataset.AddBasis(data.z.data(), n, /*num_col=*/1, /*row_major=*/false); - - // Shared residual - StochTree::ColumnVector residual(Z_vec.data(), n); - - // Global error variance model - constexpr double a_sigma = 0.0, b_sigma = 0.0; // non-informative IG prior - StochTree::GlobalHomoskedasticVarianceModel var_model; - - // Lambda function for probit data augmentation sampling step (after each forest sample) - auto post_iter = [&](double* outcome_preds, double&) { - StochTree::sample_probit_latent_outcome( - rng, y_vec.data(), outcome_preds, residual.GetData().data(), y_bar, n); - }; - - // Generate test data and build test datasets - ProbitBCFDataset test_data = generate_probit_bcf_data(n_test, p, rng); - - // Test dataset: covariates + actual treatment z (for y prediction) - StochTree::ForestDataset test_dataset; - test_dataset.AddCovariates(test_data.X.data(), n_test, p, /*row_major=*/true); - test_dataset.AddBasis(test_data.z.data(), n_test, /*num_col=*/1, /*row_major=*/false); - - // Lambda function for reporting mu/tau RMSE and last draw of global error 
variance - auto report = [&](const std::vector& mu_preds, const std::vector& tau_preds, - double global_variance) { - double mu_rmse_sum = 0.0, tau_rmse_sum = 0.0, y_rmse_sum = 0.0; - - for (int i = 0; i < n_test; i++) { - double mu_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - mu_hat += mu_preds[static_cast(j * n_test + i)] / num_mcmc; - mu_rmse_sum += (mu_hat + y_bar - test_data.mu_true[i]) * (mu_hat + y_bar - test_data.mu_true[i]); - - // tau_preds from test_dataset_cate (z=1 basis) => raw CATE estimates - double cate_hat = 0.0; - for (int j = 0; j < num_mcmc; j++) - cate_hat += tau_preds[static_cast(j * n_test + i)] / num_mcmc; - tau_rmse_sum += (cate_hat - test_data.tau_true[i]) * (cate_hat - test_data.tau_true[i]); - - double y_hat = mu_hat + y_bar + cate_hat * test_data.z[i]; - y_rmse_sum += (y_hat - test_data.latent_outcome[i]) * (y_hat - test_data.latent_outcome[i]); - } - - std::cout << "\nScenario 0 (BCF: constant mu + univariate tau with Z basis):\n" - << " mu RMSE (test): " << std::sqrt(mu_rmse_sum / n_test) << "\n" - << " tau RMSE (test): " << std::sqrt(tau_rmse_sum / n_test) << "\n" - << " latent outcome RMSE (test): " << std::sqrt(y_rmse_sum / n_test) << "\n" - << " sigma (last sample): " << std::sqrt(global_variance) << "\n" - << " sigma (truth): 1\n"; - }; - - // Dispatch BCF sampler - run_bcf_sampler(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, - dataset, residual, rng, test_dataset, post_iter, report); +// ---- Scenario 1: probit BCF (constant-leaf mu + univariate-leaf tau) ---- + +static void run_scenario_1(int n, int n_test, int p, + int num_trees_mu, int num_trees_tau, + int num_gfr, int num_mcmc, int seed) { + std::mt19937 rng(seed < 0 ? 
std::random_device{}() : static_cast(seed)); + ProbitBCFDataset train = generate_probit_bcf_data(n, p, rng); + ProbitBCFDataset test = generate_probit_bcf_data(n_test, p, rng); + + StochTree::BCFData data; + data.X_train = train.X.data(); + data.y_train = train.y.data(); + data.z_train = train.z.data(); + data.n_train = n; + data.p_x = p; + data.treatment_dim = 1; + data.X_test = test.X.data(); + data.z_test = test.z.data(); + data.n_test = n_test; + + StochTree::BCFConfig config; + config.num_trees_mu = num_trees_mu; + config.num_trees_tau = num_trees_tau; + config.num_gfr = num_gfr; + config.num_mcmc = num_mcmc; + config.random_seed = seed; + config.link_function = StochTree::LinkFunction::Probit; + config.standardize_outcome = true; + config.sample_sigma2_global = false; + + StochTree::BCFSamples samples; + StochTree::BCFSamplerFit(&samples, config, data); + // Predictions are on latent scale; compare tau to true CATE and y to latent W. + report_bcf(samples, test.mu_true, test.tau_true, test.latent_outcome, + "Scenario 1 (BCF: constant mu + univariate tau, probit link)"); } // ---- Main ----------------------------------------------------------- int main(int argc, char** argv) { - int scenario = 0; - int n = 500; - int n_test = 100; - int p = 5; - int num_trees_mu = 200; + int scenario = 0; + int n = 500; + int n_test = 100; + int p = 5; + int num_trees_mu = 200; int num_trees_tau = 50; - int num_gfr = 20; - int num_mcmc = 100; - int seed = 1234; + int num_gfr = 20; + int num_mcmc = 100; + int seed = 1234; for (int i = 1; i < argc; ++i) { std::string arg = argv[i]; if ((arg == "--scenario" || arg == "--n" || arg == "--n_test" || arg == "--p" || - arg == "--num_trees_mu" || arg == "--num_trees_tau" || arg == "--num_gfr" || arg == "--num_mcmc" || arg == "--seed") && - i + 1 < argc) { + arg == "--num_trees_mu" || arg == "--num_trees_tau" || arg == "--num_gfr" || + arg == "--num_mcmc" || arg == "--seed") && i + 1 < argc) { int val = std::stoi(argv[++i]); - if (arg 
== "--scenario") - scenario = val; - else if (arg == "--n") - n = val; - else if (arg == "--n_test") - n_test = val; - else if (arg == "--p") - p = val; - else if (arg == "--num_trees_mu") - num_trees_mu = val; - else if (arg == "--num_trees_tau") - num_trees_tau = val; - else if (arg == "--num_gfr") - num_gfr = val; - else if (arg == "--num_mcmc") - num_mcmc = val; - else if (arg == "--seed") - seed = val; + if (arg == "--scenario") scenario = val; + else if (arg == "--n") n = val; + else if (arg == "--n_test") n_test = val; + else if (arg == "--p") p = val; + else if (arg == "--num_trees_mu") num_trees_mu = val; + else if (arg == "--num_trees_tau") num_trees_tau = val; + else if (arg == "--num_gfr") num_gfr = val; + else if (arg == "--num_mcmc") num_mcmc = val; + else if (arg == "--seed") seed = val; } else { std::cerr << "Unknown or incomplete argument: " << arg << "\n" << "Usage: bcf_debug [--scenario N] [--n N] [--n_test N] [--p N]" @@ -455,15 +252,11 @@ int main(int argc, char** argv) { } switch (scenario) { - case 0: - run_scenario_0(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); - break; - case 1: - run_scenario_1(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); - break; + case 0: run_scenario_0(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); break; + case 1: run_scenario_1(n, n_test, p, num_trees_mu, num_trees_tau, num_gfr, num_mcmc, seed); break; default: std::cerr << "Unknown scenario " << scenario - << ". Available scenarios: 0 (BCF: constant mu + univariate tau)\n"; + << ". 
Available: 0 (BCF: identity), 1 (BCF: probit)\n"; return 1; } return 0; diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index c640b8a9..c93a1669 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -8,58 +8,63 @@ #include #include #include "stochtree/container.h" +#include "stochtree/meta.h" namespace StochTree { struct BARTData { // Train set covariates - double* X_train; + double* X_train = nullptr; int n_train = 0; int p = 0; // Test set covariates - double* X_test; + double* X_test = nullptr; int n_test = 0; // Train set outcome - double* y_train; + double* y_train = nullptr; // Basis for leaf regression - double* basis_train; - double* basis_test; + double* basis_train = nullptr; + double* basis_test = nullptr; int basis_dim = 0; // Observation weights - double* obs_weights_train; - double* obs_weights_test; + double* obs_weights_train = nullptr; + double* obs_weights_test = nullptr; // Random effects - int* rfx_group_ids_train; - int* rfx_group_ids_test; - double* rfx_basis_train; - double* rfx_basis_test; + int* rfx_group_ids_train = nullptr; + int* rfx_group_ids_test = nullptr; + double* rfx_basis_train = nullptr; + double* rfx_basis_test = nullptr; int rfx_num_groups = 0; int rfx_basis_dim = 0; }; struct BARTConfig { // High level parameters - bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after - int num_threads = 1; // number of threads to use for sampling - int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits - std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical - std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) + bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize 
predictions after + int num_threads = 1; // number of threads to use for sampling + int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits + std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical + std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) // Global error variance parameters - double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance - double b_sigma2_global = 0.0; // scale parameter for inverse gamma prior on global error variance - double sigma2_global_init = 1.0; // initial value for global error variance + double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance + double b_sigma2_global = 0.0; // scale parameter for inverse gamma prior on global error variance + double sigma2_global_init = 1.0; // initial value for global error variance + bool probit = false; // whether to use probit link (if true, global error variance is not sampled and latent outcomes are sampled instead) + int random_seed = -1; // random seed for reproducibility (if negative, a random seed will be generated) + bool sample_sigma2_global = true; // whether to sample global error variance (if false, it will be fixed at sigma2_global_init) // Mean forest parameters int num_trees_mean = 200; // number of trees in the mean forest double alpha_mean = 0.95; // alpha parameter for mean forest tree prior double beta_mean = 2.0; // beta parameter for mean forest tree prior int min_samples_leaf_mean = 5; // minimum number of samples per leaf for mean forest + int max_depth_mean = -1; // maximum depth for mean forest trees (-1 means no maximum) bool leaf_constant_mean = true; // whether to use constant leaf model for mean forest int leaf_dim_mean = 1; // dimension of the leaf for mean 
forest bool exponentiated_leaf_mean = false; // whether to exponentiate leaf predictions for mean forest @@ -68,12 +73,14 @@ struct BARTConfig { double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) double sigma2_mean_init = -1.0; // initial value of mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) + bool sample_sigma2_leaf_mean = true; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) // Variance forest parameters int num_trees_variance = 0; // number of trees in the variance forest double alpha_variance = 0.5; // alpha parameter for variance forest tree prior double beta_variance = 2.0; // beta parameter for variance forest tree prior int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest + int max_depth_variance = -1; // maximum depth for variance forest trees (-1 means no maximum) bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest @@ -111,6 +118,13 @@ struct BARTSamples { std::unique_ptr variance_forests; // TODO: Pointer to random effects samples ... + + // Metadata about the samples (e.g., number of samples, burn-in, etc.) 
could be added here as needed + int num_samples = 0; + int num_train = 0; + int num_test = 0; + double y_bar = 0.0; + double y_std = 0.0; }; } // namespace StochTree diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h new file mode 100644 index 00000000..92077c12 --- /dev/null +++ b/include/stochtree/bart_sampler.h @@ -0,0 +1,80 @@ +/*! + * Copyright (c) 2026 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef STOCHTREE_BART_SAMPLER_H_ +#define STOCHTREE_BART_SAMPLER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "stochtree/prior.h" + +namespace StochTree { + +class BARTSampler { + public: + BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data); + + // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions + void run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr); + + // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions + void run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc); + + private: + /*! Initialize state variables */ + void InitializeState(BARTSamples& samples, BARTConfig& config, BARTData& data); + bool initialized_ = false; + + /*! Mean forest state */ + std::unique_ptr mean_forest_; + std::unique_ptr mean_forest_tracker_; + std::unique_ptr tree_prior_mean_; + bool has_mean_forest_ = false; + + /*! Variance forest state */ + std::unique_ptr variance_forest_; + std::unique_ptr variance_forest_tracker_; + std::unique_ptr tree_prior_variance_; + bool has_variance_forest_ = false; + + /*! 
Dataset */ + std::unique_ptr residual_; + std::unique_ptr outcome_raw_; + std::unique_ptr forest_dataset_; + std::unique_ptr forest_dataset_test_; + bool has_test_ = false; + + /*! Random number generator */ + std::mt19937 rng_; + + /*! Model parameters */ + double global_variance_; + double leaf_scale_; + std::vector leaf_scale_multivariate_; + + // Global error scale model + std::unique_ptr var_model_; + bool sample_sigma2_global_ = false; + + // Leaf scale model + std::unique_ptr leaf_scale_model_; + bool sample_sigma2_leaf_ = false; + + /*! Random effects state */ + // TODO ... + + /*! Vector of warm-start snapshots (forests needed for MCMC chains but not retained) */ + std::vector warm_start_forests_mean_; + std::vector warm_start_forests_variance_; +}; + +} // namespace StochTree + +#endif // STOCHTREE_BART_SAMPLER_H_ diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp new file mode 100644 index 00000000..a85af4ce --- /dev/null +++ b/src/bart_sampler.cpp @@ -0,0 +1,181 @@ +/*! 
Copyright (c) 2026 by stochtree authors */ +#include +#include +#include +#include +#include +#include +#include + +namespace StochTree { + +BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data) { + InitializeState(samples, config, data); +} + +void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BARTData& data) { + // Load data from BARTData object into ForestDataset object + forest_dataset_ = std::make_unique(); + forest_dataset_->AddCovariates(data.X_train, data.n_train, data.p, /*row_major=*/false); + if (data.basis_train != nullptr) { + forest_dataset_->AddBasis(data.basis_train, data.n_train, data.basis_dim, /*row_major=*/false); + } + if (data.obs_weights_train != nullptr) { + forest_dataset_->AddVarianceWeights(data.obs_weights_train, data.n_train); + } + samples.num_train = data.n_train; + samples.num_test = data.n_test; + residual_ = std::make_unique(data.y_train, data.n_train); + outcome_raw_ = std::make_unique(data.y_train, data.n_train); + if (data.X_test != nullptr) { + forest_dataset_test_ = std::make_unique(); + forest_dataset_test_->AddCovariates(data.X_test, data.n_test, data.p, /*row_major=*/false); + if (data.basis_test != nullptr) { + forest_dataset_test_->AddBasis(data.basis_test, data.n_test, data.basis_dim, /*row_major=*/false); + } + if (data.obs_weights_test != nullptr) { + forest_dataset_test_->AddVarianceWeights(data.obs_weights_test, data.n_test); + } + has_test_ = true; + } + + // Compute outcome location and scale for standardization + samples.y_bar = 0.0; + samples.y_std = 0.0; + for (int i = 0; i < data.n_train; i++) samples.y_bar += data.y_train[i]; + samples.y_bar /= data.n_train; + for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); + samples.y_std = std::sqrt(samples.y_std / data.n_train); + + // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized 
outcomes + for (int i = 0; i < data.n_train; i++) residual_->GetData()[i] = (data.y_train[i] - samples.y_bar) / samples.y_std; + + // Initialize mean forest state (if present) + if (config.num_trees_mean > 0) { + mean_forest_ = std::make_unique(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); + mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config.feature_types, config.num_trees_mean, data.n_train); + tree_prior_mean_ = std::make_unique(config.alpha_mean, config.beta_mean, config.min_samples_leaf_mean, config.max_depth_mean); + mean_forest_->SetLeafValue(0.0); + UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config.leaf_constant_mean, std::minus()); + mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); + has_mean_forest_ = true; + if (config.sigma2_mean_init < 0.0) { + config.sigma2_mean_init = (samples.y_std * samples.y_std) / config.num_trees_mean; + } + } + + // Initialize variance forest state (if present) + if (config.num_trees_variance > 0) { + variance_forest_ = std::make_unique(config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance); + samples.variance_forests = std::make_unique(config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance); + variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config.feature_types, config.num_trees_variance, data.n_train); + tree_prior_variance_ = std::make_unique(config.alpha_variance, config.beta_variance, config.min_samples_leaf_variance, config.max_depth_variance); + variance_forest_->SetLeafValue(1.0 / config.num_trees_variance); + 
variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); + has_variance_forest_ = true; + } + + // Global error variance model + if (config.sample_sigma2_global) { + var_model_ = std::make_unique(); + sample_sigma2_global_ = true; + } + + // Leaf scale model + if (config.sample_sigma2_leaf_mean) { + leaf_scale_model_ = std::make_unique(); + sample_sigma2_leaf_ = true; + } + + // RNG + rng_ = std::mt19937(config.random_seed >= 0 ? config.random_seed : std::random_device{}()); + + // Other internal model state + global_variance_ = config.sigma2_global_init; + leaf_scale_ = config.sigma2_mean_init; + // leaf_scale_multivariate_ = config.sigma2_leaf_multivariate_init; +} + +void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr) { + // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance + for (int i = 0; i < num_gfr; i++) { + if (has_mean_forest_) { + GaussianConstantLeafModel leaf_model(leaf_scale_); + GFRSampleOneIter( + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, + *forest_dataset_, *residual_, *tree_prior_mean_, rng, + config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, + config.cutpoint_grid_size, /*keep_forest=*/keep_gfr, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_features_subsample=*/data.p, config.num_threads); + } + + if (config.probit) { + sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), + residual_->GetData().data(), samples.y_bar, data.n_train); + } + + if (sample_sigma2_global_) { + global_variance_ = var_model_->SampleVarianceParameter( + residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); + } + + if (keep_gfr) { + samples.num_samples++; + if (sample_sigma2_global_) 
samples.global_error_variance_samples.push_back(global_variance_); + if (has_mean_forest_) { + samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), + mean_forest_tracker_->GetSumPredictions(), mean_forest_tracker_->GetSumPredictions() + samples.num_train); + int num_samples = samples.mean_forests->NumSamples(); + std::vector predictions = samples.mean_forests->GetEnsemble(num_samples - 1)->Predict(*forest_dataset_test_); + samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), + predictions.data(), predictions.data() + samples.num_test); + } + } + } +} + +void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc) { + // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance + bool keep_forest = false; + for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { + if (i >= num_burnin && (i - num_burnin) % keep_every == 0) { + keep_forest = true; + } else { + keep_forest = false; + } + if (has_mean_forest_) { + GaussianConstantLeafModel leaf_model(leaf_scale_); + MCMCSampleOneIter( + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, + *forest_dataset_, *residual_, *tree_prior_mean_, rng, + config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_forest, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_threads=*/config.num_threads); + } + + if (config.probit) { + sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), + residual_->GetData().data(), samples.y_bar, data.n_train); + } + + if (sample_sigma2_global_) { + global_variance_ = var_model_->SampleVarianceParameter( + residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); + } + + if (keep_forest) { + samples.num_samples++; 
+ if (sample_sigma2_global_) samples.global_error_variance_samples.push_back(global_variance_); + if (has_mean_forest_) { + samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), + mean_forest_tracker_->GetSumPredictions(), mean_forest_tracker_->GetSumPredictions() + samples.num_train); + int num_samples = samples.mean_forests->NumSamples(); + samples.mean_forests->GetEnsemble(num_samples - 1)->PredictInplace(*forest_dataset_test_, samples.mean_forest_predictions_test, (num_samples - 1) * samples.num_test); + } + } + } +} + +} // namespace StochTree From 2c05a328fd11f58482b2d0943f2f18053dcb3445 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 09:52:15 -0400 Subject: [PATCH 23/64] Working implementation of BARTSampler --- .vscode/launch.json | 16 ++++++++-------- debug/bart_debug.cpp | 10 +++++----- src/bart_sampler.cpp | 24 +++++++++++++++++------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index fa7e941c..1f46e899 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -15,7 +15,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", 
"50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (dev-quick)" }, @@ -33,7 +33,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -43,7 +43,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (dev-quick)" @@ -63,7 +63,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -72,7 +72,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bart_debug", - "args": ["--scenario", "0", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + 
"args": ["--scenario", "1", "--n", "500", "--n_test", "100", "--p", "5", "--num_trees", "200", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" @@ -82,7 +82,7 @@ "type": "lldb", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: Build (release-drivers)" }, @@ -91,7 +91,7 @@ "type": "cppdbg", "request": "launch", "program": "${workspaceFolder}/build-release-drivers/bcf_debug", - "args": ["--scenario", "0", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], + "args": ["--scenario", "1", "--n", "200", "--n_test", "100", "--p", "5", "--num_trees_mu", "200", "--num_trees_tau", "50", "--num_gfr", "10", "--num_mcmc", "100", "--seed", "-1"], "cwd": "${workspaceFolder}", "MIMode": "gdb", "preLaunchTask": "CMake: Build (release-drivers)" diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 2251fda4..545995de 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -88,17 +88,17 @@ static void report_bart(const StochTree::BARTSamples& samples, const int n_test = samples.num_test; double rmse_sum = 0.0; for (int i = 0; i < n_test; i++) { - double mu_hat = 0.0; + double y_hat = 0.0; for (int j = 0; j < num_samples; j++) - mu_hat += samples.mean_forest_predictions_test[static_cast(j * n_test + i)] / num_samples; - double err = mu_hat - test_ref[i]; + y_hat += samples.mean_forest_predictions_test[static_cast(j * 
n_test + i)] / num_samples; + double err = (y_hat * samples.y_std + samples.y_bar) - test_ref[i]; rmse_sum += err * err; } std::cout << "\n" << scenario_name << ":\n" << " RMSE (test): " << std::sqrt(rmse_sum / n_test) << "\n"; if (!samples.global_error_variance_samples.empty()) { - std::cout << " sigma (last): " << std::sqrt(samples.global_error_variance_samples.back()) << "\n" + std::cout << " sigma (last): " << std::sqrt(samples.global_error_variance_samples.back()) * samples.y_std << "\n" << " sigma (truth): 1.0\n"; } } @@ -123,7 +123,7 @@ static void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, config.random_seed = seed; config.probit = false; config.standardize_outcome = true; - config.sample_sigma2_global = false; + config.sample_sigma2_global = true; config.var_weights_mean = std::vector(p, 1.0 / p); config.feature_types = std::vector(p, StochTree::FeatureType::kNumeric); config.sweep_update_indices = std::vector(num_trees, 0); diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index a85af4ce..c6048848 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -1,10 +1,12 @@ /*! 
Copyright (c) 2026 by stochtree authors */ #include #include +#include #include #include #include #include +#include #include namespace StochTree { @@ -40,12 +42,18 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART } // Compute outcome location and scale for standardization - samples.y_bar = 0.0; - samples.y_std = 0.0; - for (int i = 0; i < data.n_train; i++) samples.y_bar += data.y_train[i]; - samples.y_bar /= data.n_train; - for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); - samples.y_std = std::sqrt(samples.y_std / data.n_train); + if (config.probit) { + samples.y_std = 1.0; + double y_mean = std::accumulate(data.y_train, data.y_train + data.n_train, 0.0) / data.n_train; + samples.y_bar = norm_cdf(y_mean); + } else { + samples.y_bar = 0.0; + samples.y_std = 0.0; + for (int i = 0; i < data.n_train; i++) samples.y_bar += data.y_train[i]; + samples.y_bar /= data.n_train; + for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); + samples.y_std = std::sqrt(samples.y_std / data.n_train); + } // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized outcomes for (int i = 0; i < data.n_train; i++) residual_->GetData()[i] = (data.y_train[i] - samples.y_bar) / samples.y_std; @@ -172,7 +180,9 @@ void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& d samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), mean_forest_tracker_->GetSumPredictions(), mean_forest_tracker_->GetSumPredictions() + samples.num_train); int num_samples = samples.mean_forests->NumSamples(); - samples.mean_forests->GetEnsemble(num_samples - 1)->PredictInplace(*forest_dataset_test_, samples.mean_forest_predictions_test, (num_samples - 1) * samples.num_test); + std::vector predictions = 
samples.mean_forests->GetEnsemble(num_samples - 1)->Predict(*forest_dataset_test_); + samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), + predictions.data(), predictions.data() + samples.num_test); } } } From 65d4d540d78dbf4b6a84b9a16bd4c8be3bfe298b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 10:23:36 -0400 Subject: [PATCH 24/64] Updated sampler and debug program --- debug/bart_debug.cpp | 5 +- include/stochtree/bart.h | 31 +++++--- include/stochtree/bart_sampler.h | 6 +- src/bart_sampler.cpp | 124 +++++++++++++++---------------- 4 files changed, 87 insertions(+), 79 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 545995de..6772aec1 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -121,7 +121,7 @@ static void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, StochTree::BARTConfig config; config.num_trees_mean = num_trees; config.random_seed = seed; - config.probit = false; + config.link_function = StochTree::LinkFunction::Identity; config.standardize_outcome = true; config.sample_sigma2_global = true; config.var_weights_mean = std::vector(p, 1.0 / p); @@ -154,8 +154,7 @@ static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, StochTree::BARTConfig config; config.num_trees_mean = num_trees; config.random_seed = seed; - config.probit = true; - config.standardize_outcome = true; + config.link_function = StochTree::LinkFunction::Probit; config.sample_sigma2_global = false; config.var_weights_mean = std::vector(p, 1.0 / p); config.feature_types = std::vector(p, StochTree::FeatureType::kNumeric); diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index c93a1669..addb6428 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -7,11 +7,23 @@ #include #include -#include "stochtree/container.h" -#include "stochtree/meta.h" +#include +#include namespace StochTree { +enum class LinkFunction { + 
Identity, + Probit, + Cloglog +}; + +enum class OutcomeType { + Continuous, + Binary, + Ordinal +}; + struct BARTData { // Train set covariates double* X_train = nullptr; @@ -45,18 +57,19 @@ struct BARTData { struct BARTConfig { // High level parameters - bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after - int num_threads = 1; // number of threads to use for sampling - int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits - std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical - std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) + bool standardize_outcome = true; // whether to standardize the outcome before fitting and unstandardize predictions after + int num_threads = 1; // number of threads to use for sampling + int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits + std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical + std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) + LinkFunction link_function = LinkFunction::Identity; // link function to use (Identity, Probit, Cloglog) + OutcomeType outcome_type = OutcomeType::Continuous; // type of the outcome variable (Continuous, Binary, Ordinal) + int random_seed = -1; // random seed for reproducibility (if negative, a random seed will be generated) // Global error variance parameters double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance double b_sigma2_global = 0.0; // scale parameter for inverse gamma prior on global error variance double 
sigma2_global_init = 1.0; // initial value for global error variance - bool probit = false; // whether to use probit link (if true, global error variance is not sampled and latent outcomes are sampled instead) - int random_seed = -1; // random seed for reproducibility (if negative, a random seed will be generated) bool sample_sigma2_global = true; // whether to sample global error variance (if false, it will be fixed at sigma2_global_init) // Mean forest parameters diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 92077c12..74293daa 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -9,11 +9,12 @@ #include #include #include +#include #include +#include #include #include #include -#include "stochtree/prior.h" namespace StochTree { @@ -32,6 +33,9 @@ class BARTSampler { void InitializeState(BARTSamples& samples, BARTConfig& config, BARTData& data); bool initialized_ = false; + /*! Internal sample runner function */ + void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel& leaf_model, std::mt19937& rng, bool gfr, bool keep_sample); + /*! 
Mean forest state */ std::unique_ptr mean_forest_; std::unique_ptr mean_forest_tracker_; diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index c6048848..de450e50 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -42,17 +42,22 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART } // Compute outcome location and scale for standardization - if (config.probit) { + if (config.link_function == LinkFunction::Probit) { samples.y_std = 1.0; double y_mean = std::accumulate(data.y_train, data.y_train + data.n_train, 0.0) / data.n_train; samples.y_bar = norm_cdf(y_mean); } else { - samples.y_bar = 0.0; - samples.y_std = 0.0; - for (int i = 0; i < data.n_train; i++) samples.y_bar += data.y_train[i]; - samples.y_bar /= data.n_train; - for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); - samples.y_std = std::sqrt(samples.y_std / data.n_train); + if (config.standardize_outcome) { + samples.y_bar = 0.0; + samples.y_std = 0.0; + for (int i = 0; i < data.n_train; i++) samples.y_bar += data.y_train[i]; + samples.y_bar /= data.n_train; + for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); + samples.y_std = std::sqrt(samples.y_std / data.n_train); + } else { + samples.y_bar = 0.0; + samples.y_std = 1.0; + } } // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized outcomes @@ -107,83 +112,70 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr) { // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance + GaussianConstantLeafModel leaf_model(leaf_scale_); for (int i = 0; i < num_gfr; 
i++) { - if (has_mean_forest_) { - GaussianConstantLeafModel leaf_model(leaf_scale_); - GFRSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng, - config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, - config.cutpoint_grid_size, /*keep_forest=*/keep_gfr, - /*pre_initialized=*/true, /*backfitting=*/true, - /*num_features_subsample=*/data.p, config.num_threads); - } - - if (config.probit) { - sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), - residual_->GetData().data(), samples.y_bar, data.n_train); - } - - if (sample_sigma2_global_) { - global_variance_ = var_model_->SampleVarianceParameter( - residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); - } - - if (keep_gfr) { - samples.num_samples++; - if (sample_sigma2_global_) samples.global_error_variance_samples.push_back(global_variance_); - if (has_mean_forest_) { - samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), - mean_forest_tracker_->GetSumPredictions(), mean_forest_tracker_->GetSumPredictions() + samples.num_train); - int num_samples = samples.mean_forests->NumSamples(); - std::vector predictions = samples.mean_forests->GetEnsemble(num_samples - 1)->Predict(*forest_dataset_test_); - samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), - predictions.data(), predictions.data() + samples.num_test); - } - } + RunOneIteration(samples, config, data, leaf_model, rng, /*gfr=*/true, /*keep_sample=*/keep_gfr); } } void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc) { - // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance + 
GaussianConstantLeafModel leaf_model(leaf_scale_); bool keep_forest = false; for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { - if (i >= num_burnin && (i - num_burnin) % keep_every == 0) { + if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; - } else { + else keep_forest = false; - } - if (has_mean_forest_) { - GaussianConstantLeafModel leaf_model(leaf_scale_); + RunOneIteration(samples, config, data, leaf_model, rng, /*gfr=*/false, /*keep_sample=*/keep_forest); + } +} + +void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel& leaf_model, std::mt19937& rng, bool gfr, bool keep_sample) { + if (has_mean_forest_) { + if (gfr) { + GFRSampleOneIter( + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, + *forest_dataset_, *residual_, *tree_prior_mean_, rng, + config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, + config.cutpoint_grid_size, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_features_subsample=*/data.p, config.num_threads); + } else { MCMCSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng, - config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_forest, + config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, /*num_threads=*/config.num_threads); } + } - if (config.probit) { - sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), - residual_->GetData().data(), samples.y_bar, data.n_train); - } + if (config.link_function == LinkFunction::Probit) { + sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), + 
residual_->GetData().data(), samples.y_bar, data.n_train); + } - if (sample_sigma2_global_) { - global_variance_ = var_model_->SampleVarianceParameter( - residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); - } + if (sample_sigma2_global_) { + global_variance_ = var_model_->SampleVarianceParameter( + residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); + } + + if (sample_sigma2_leaf_) { + leaf_scale_ = leaf_scale_model_->SampleVarianceParameter( + mean_forest_.get(), config.a_sigma2_mean, config.b_sigma2_mean, rng_); + } - if (keep_forest) { - samples.num_samples++; - if (sample_sigma2_global_) samples.global_error_variance_samples.push_back(global_variance_); - if (has_mean_forest_) { - samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), - mean_forest_tracker_->GetSumPredictions(), mean_forest_tracker_->GetSumPredictions() + samples.num_train); - int num_samples = samples.mean_forests->NumSamples(); - std::vector predictions = samples.mean_forests->GetEnsemble(num_samples - 1)->Predict(*forest_dataset_test_); - samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), - predictions.data(), predictions.data() + samples.num_test); - } + if (keep_sample) { + samples.num_samples++; + if (sample_sigma2_global_) samples.global_error_variance_samples.push_back(global_variance_); + if (sample_sigma2_leaf_) samples.leaf_scale_samples.push_back(leaf_scale_); + if (has_mean_forest_) { + double* mean_forest_preds_train = mean_forest_tracker_->GetSumPredictions(); + samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), + mean_forest_preds_train, mean_forest_preds_train + samples.num_train); + std::vector predictions = samples.mean_forests->GetEnsemble(samples.num_samples - 1)->Predict(*forest_dataset_test_); + samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), + predictions.data(), 
predictions.data() + samples.num_test); } } } From 35c8f6f53a5c83bca3f2d871b1aaa5ff836fb125 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 11:17:29 -0400 Subject: [PATCH 25/64] Updated probit and leaf scale initialization --- include/stochtree/bart.h | 25 +++++---- include/stochtree/bart_sampler.h | 2 +- src/bart_sampler.cpp | 94 +++++++++++++++++++++++++------- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index addb6428..5548bef2 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -86,19 +86,22 @@ struct BARTConfig { double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) double sigma2_mean_init = -1.0; // initial value of mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) - bool sample_sigma2_leaf_mean = true; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) + bool sample_sigma2_leaf_mean = false; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) // Variance forest parameters - int num_trees_variance = 0; // number of trees in the variance forest - double alpha_variance = 0.5; // alpha parameter for variance forest tree prior - double beta_variance = 2.0; // beta parameter for variance forest tree prior - int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest - int max_depth_variance = -1; // maximum depth for variance forest trees (-1 means no maximum) - bool leaf_constant_variance = true; // whether to use constant leaf model for 
variance forest - int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) - bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest - int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) - std::vector var_weights_variance; // variable weights for variance forest splits (should be same length as number of covariates in the dataset) + int num_trees_variance = 0; // number of trees in the variance forest + double leaf_prior_calibration_param = 1.5; // calibration parameter for variance forest leaf prior + double shape_variance_forest = -1.0; // shape parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) + double scale_variance_forest = -1.0; // scale parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) + double alpha_variance = 0.5; // alpha parameter for variance forest tree prior + double beta_variance = 2.0; // beta parameter for variance forest tree prior + int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest + int max_depth_variance = -1; // maximum depth for variance forest trees (-1 means no maximum) + bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest + int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) + bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest + int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) + std::vector var_weights_variance; // variable weights for variance forest splits (should be same length as number of covariates in 
the dataset) // TODO: Random effects parameters ... diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 74293daa..fb1b96bd 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -34,7 +34,7 @@ class BARTSampler { bool initialized_ = false; /*! Internal sample runner function */ - void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel& leaf_model, std::mt19937& rng, bool gfr, bool keep_sample); + void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, std::mt19937& rng, bool gfr, bool keep_sample); /*! Mean forest state */ std::unique_ptr mean_forest_; diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index de450e50..8992fc6b 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -6,8 +6,10 @@ #include #include #include +#include #include #include +#include "stochtree/leaf_model.h" namespace StochTree { @@ -41,19 +43,23 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART has_test_ = true; } + // Precompute outcome mean and variance for standardization and calibration + double y_mean = 0.0, M2 = 0.0, y_mean_prev = 0.0; + for (int i = 0; i < data.n_train; i++) { + y_mean_prev = y_mean; + y_mean = y_mean_prev + (data.y_train[i] - y_mean_prev) / (i + 1); + M2 = M2 + (data.y_train[i] - y_mean_prev) * (data.y_train[i] - y_mean); + } + double y_var = M2 / data.n_train; + // Compute outcome location and scale for standardization if (config.link_function == LinkFunction::Probit) { samples.y_std = 1.0; - double y_mean = std::accumulate(data.y_train, data.y_train + data.n_train, 0.0) / data.n_train; - samples.y_bar = norm_cdf(y_mean); + samples.y_bar = norm_inv_cdf(y_mean); } else { if (config.standardize_outcome) { - samples.y_bar = 0.0; - samples.y_std = 0.0; - for (int i = 0; i < data.n_train; 
i++) samples.y_bar += data.y_train[i]; - samples.y_bar /= data.n_train; - for (int i = 0; i < data.n_train; i++) samples.y_std += (data.y_train[i] - samples.y_bar) * (data.y_train[i] - samples.y_bar); - samples.y_std = std::sqrt(samples.y_std / data.n_train); + samples.y_bar = y_mean; + samples.y_std = std::sqrt(y_var); } else { samples.y_bar = 0.0; samples.y_std = 1.0; @@ -74,7 +80,20 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); has_mean_forest_ = true; if (config.sigma2_mean_init < 0.0) { - config.sigma2_mean_init = (samples.y_std * samples.y_std) / config.num_trees_mean; + if (config.link_function == LinkFunction::Probit) { + config.sigma2_mean_init = 1.0 / config.num_trees_mean; + } else { + config.sigma2_mean_init = y_var / config.num_trees_mean; + } + } + if (sample_sigma2_leaf_) { + if (config.b_sigma2_mean <= 0.0) { + if (config.link_function == LinkFunction::Probit) { + config.b_sigma2_mean = 1.0 / (2 * config.num_trees_mean); + } else { + config.b_sigma2_mean = y_var / (2 * config.num_trees_mean); + } + } } } @@ -87,6 +106,17 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART variance_forest_->SetLeafValue(1.0 / config.num_trees_variance); variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); has_variance_forest_ = true; + if (config.shape_variance_forest <= 0.0 || config.scale_variance_forest <= 0.0) { + if (config.leaf_prior_calibration_param <= 0.0) { + config.leaf_prior_calibration_param = 1.5; + } + if (config.shape_variance_forest <= 0.0) { + config.shape_variance_forest = config.num_trees_variance / (config.leaf_prior_calibration_param * config.leaf_prior_calibration_param) + 0.5; + } + if (config.scale_variance_forest <= 0.0) { + config.scale_variance_forest = config.num_trees_variance / (config.leaf_prior_calibration_param * 
config.leaf_prior_calibration_param); + } + } } // Global error variance model @@ -108,33 +138,37 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART global_variance_ = config.sigma2_global_init; leaf_scale_ = config.sigma2_mean_init; // leaf_scale_multivariate_ = config.sigma2_leaf_multivariate_init; + + initialized_ = true; } void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr) { // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance - GaussianConstantLeafModel leaf_model(leaf_scale_); + std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); + std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); for (int i = 0; i < num_gfr; i++) { - RunOneIteration(samples, config, data, leaf_model, rng, /*gfr=*/true, /*keep_sample=*/keep_gfr); + RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), rng, /*gfr=*/true, /*keep_sample=*/keep_gfr); } } void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc) { - GaussianConstantLeafModel leaf_model(leaf_scale_); + std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); + std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); bool keep_forest = false; for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; else keep_forest = false; - RunOneIteration(samples, config, data, leaf_model, rng, /*gfr=*/false, /*keep_sample=*/keep_forest); + RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), rng, /*gfr=*/false, 
/*keep_sample=*/keep_forest); } } -void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel& leaf_model, std::mt19937& rng, bool gfr, bool keep_sample) { +void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, std::mt19937& rng, bool gfr, bool keep_sample) { if (has_mean_forest_) { if (gfr) { GFRSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng, config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, config.cutpoint_grid_size, /*keep_forest=*/keep_sample, @@ -142,7 +176,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART /*num_features_subsample=*/data.p, config.num_threads); } else { MCMCSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, leaf_model, + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng, config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, @@ -150,6 +184,25 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART } } + if (has_variance_forest_) { + if (gfr) { + GFRSampleOneIter( + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *variance_leaf_model, + *forest_dataset_, *residual_, *tree_prior_mean_, rng, + config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, + config.cutpoint_grid_size, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/false, + /*num_features_subsample=*/data.p, config.num_threads); + } else { + 
MCMCSampleOneIter( + *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *variance_leaf_model, + *forest_dataset_, *residual_, *tree_prior_mean_, rng, + config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/false, + /*num_threads=*/config.num_threads); + } + } + if (config.link_function == LinkFunction::Probit) { sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), residual_->GetData().data(), samples.y_bar, data.n_train); @@ -163,6 +216,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART if (sample_sigma2_leaf_) { leaf_scale_ = leaf_scale_model_->SampleVarianceParameter( mean_forest_.get(), config.a_sigma2_mean, config.b_sigma2_mean, rng_); + mean_leaf_model->SetScale(leaf_scale_); } if (keep_sample) { @@ -173,9 +227,11 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART double* mean_forest_preds_train = mean_forest_tracker_->GetSumPredictions(); samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), mean_forest_preds_train, mean_forest_preds_train + samples.num_train); - std::vector predictions = samples.mean_forests->GetEnsemble(samples.num_samples - 1)->Predict(*forest_dataset_test_); - samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), - predictions.data(), predictions.data() + samples.num_test); + if (has_test_) { + std::vector predictions = samples.mean_forests->GetEnsemble(samples.num_samples - 1)->Predict(*forest_dataset_test_); + samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), + predictions.data(), predictions.data() + samples.num_test); + } } } } From d7c374124cb2b3b649cf6a2329e6b8b452ccdbe0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 11:20:36 -0400 Subject: [PATCH 26/64] Correctly reference variance 
model terms --- src/bart_sampler.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 8992fc6b..75e90e81 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -187,17 +187,17 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART if (has_variance_forest_) { if (gfr) { GFRSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *variance_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng, - config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, + *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, + *forest_dataset_, *residual_, *tree_prior_variance_, rng, + config.var_weights_variance, config.sweep_update_indices, global_variance_, config.feature_types, config.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, /*num_features_subsample=*/data.p, config.num_threads); } else { MCMCSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *variance_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng, - config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, + *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, + *forest_dataset_, *residual_, *tree_prior_variance_, rng, + config.var_weights_variance, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, /*num_threads=*/config.num_threads); } From 4390de111c2fb571d9cde823b14c12bebbec54b8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 18:04:47 -0400 Subject: [PATCH 27/64] Initial implementation of R wrapper around the new BARTSampler interface --- R/bart.R | 2645 +++++++++++++++------------- R/cpp11.R | 4 + debug/bart_debug.cpp | 4 
+- debug/benchmark_cpp_vs_r_sampler.R | 151 ++ include/stochtree/bart_sampler.h | 6 +- man/bart.Rd | 5 +- src/Makevars.in | 2 + src/R_bart.cpp | 253 +++ src/bart_sampler.cpp | 21 +- src/cpp11.cpp | 8 + src/stochtree_types.h | 2 + 11 files changed, 1816 insertions(+), 1285 deletions(-) create mode 100644 debug/benchmark_cpp_vs_r_sampler.R create mode 100644 src/R_bart.cpp diff --git a/R/bart.R b/R/bart.R index e27e945f..e92a0c48 100644 --- a/R/bart.R +++ b/R/bart.R @@ -157,6 +157,8 @@ NULL #' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' +#' @param run_cpp Whether or not to run the core C++ sampler. This is exposed as an argument for testing purposes, but in general should be left as `TRUE`. If `FALSE`, the function will run the previous version of the BART sampler in which the core loop logic was implemented in R, with C++ calls for most computationally intensive steps. +#' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
#' @export #' @@ -203,7 +205,8 @@ bart <- function( general_params = list(), mean_forest_params = list(), variance_forest_params = list(), - random_effects_params = list() + random_effects_params = list(), + run_cpp = TRUE ) { # Update general BART parameters general_params_default <- list( @@ -388,19 +391,6 @@ bart <- function( )) } - # Set a function-scoped RNG if user provided a random seed - custom_rng <- random_seed >= 0 - has_existing_random_seed <- F - if (custom_rng) { - # Cache original global environment RNG state (if it exists) - if (exists(".Random.seed", envir = .GlobalEnv)) { - original_global_seed <- .Random.seed - has_existing_random_seed <- T - } - # Set new seed and store associated RNG state - set.seed(random_seed) - } - # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { if (num_chains > num_gfr) { @@ -1055,163 +1045,6 @@ bart <- function( } } - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if (link_is_probit) { - # Probit-scale intercept: center the forest on the population-average latent mean. - # The forest predicts mu(X) and y_bar_train is added back at prediction time. - # The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update. 
- y_bar_train <- qnorm(mean_cpp(as.numeric(y_train))) - y_std_train <- 1 - standardize <- FALSE - - # Set a pseudo outcome by subtracting mean_cpp(y_train) from y_train - resid_train <- y_train - mean_cpp(as.numeric(y_train)) - - # Set initial values of root nodes to 0.0 (in probit scale) - init_val_mean <- 0.0 - - # Calibrate priors for sigma^2 and tau - # Set sigma2_init to 1, ignoring default provided - sigma2_init <- 1.0 - # Skip variance_forest_init, since variance forests are not supported with probit link - if (is.null(b_leaf)) { - b_leaf <- 1 / (num_trees_mean) - } - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - current_sigma2 <- sigma2_init - } else if (link_is_cloglog) { - # Fix offset to 0 and scale to 1 - y_bar_train <- 0 - y_std_train <- 1 - standardize <- FALSE - - # Remap outcomes to start from 0 - resid_train <- as.numeric(y_train - min(unique_outcomes)) - cloglog_num_categories <- max(resid_train) + 1 - - # Set initial values of root nodes to 0.0 (in linear scale) - init_val_mean <- 0.0 - - # Calibrate priors for sigma^2 and tau - # Set sigma2_init to 1, ignoring default provided - sigma2_init <- 1.0 - if 
(is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - current_sigma2 <- sigma2_init - current_leaf_scale <- sigma2_leaf_init - - # Set first cutpoint to 0 for identifiability - cloglog_cutpoint_0 <- 0 - - # Set shape and rate parameters for conditional gamma model - cloglog_forest_shape <- 2.0 - cloglog_forest_rate <- 2.0 - } else { - # Only standardize if user requested - if (standardize) { - y_bar_train <- mean_cpp(as.numeric(y_train)) - y_std_train <- sd_cpp(as.numeric(y_train)) - } else { - y_bar_train <- 0 - y_std_train <- 1 - } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Compute initial value of root nodes in mean forest - init_val_mean <- mean_cpp(as.numeric(resid_train)) - - # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var_cpp(as.numeric(resid_train)) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var_cpp(as.numeric(resid_train)) - } - if (is.null(b_leaf)) { - b_leaf <- var_cpp(as.numeric(resid_train)) / (2 * num_trees_mean) - } - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - 
current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - current_sigma2 <- sigma2_init - } - # Determine leaf model type if ((!has_basis) && (!link_is_cloglog)) { leaf_model_mean_forest <- 0 @@ -1257,698 +1090,853 @@ bart <- function( leaf_regression = FALSE } - # Data - if (leaf_regression) { - forest_dataset_train <- createForestDataset( - X_train, - leaf_basis_train, - observation_weights + if (run_cpp) { + # Specify the BART config + bart_config <- list( + "standardize_outcome" = standardize, + "num_threads" = num_threads, + "cutpoint_grid_size" = cutpoint_grid_size, + "link_function" = ifelse( + outcome_model$link == "identity", + 0, + ifelse(outcome_model$link == "probit", 1, 2) + ), + "outcome_type" = ifelse( + outcome_model$outcome == "continuous", + 0, + ifelse(outcome_model$outcome == "binary", 1, 2) + ), + "random_seed" = random_seed, + "a_sigma2_global" = a_global, + "b_sigma2_global" = b_global, + "sigma2_global_init" = sigma2_init, + "sample_sigma2_global" = sample_sigma2_global, + "num_trees_mean" = num_trees_mean, + "alpha_mean" = alpha_mean, + "beta_mean" = beta_mean, + "min_samples_leaf_mean" = min_samples_leaf_mean, + "max_depth_mean" = max_depth_mean, + "leaf_constant_mean" = is_leaf_constant, + "leaf_dim_mean" = leaf_dimension, + "exponentiated_leaf_mean" = FALSE, + "num_features_subsample_mean" = num_features_subsample_mean, + "a_sigma2_mean" = a_leaf, + "b_sigma2_mean" = b_leaf, + "sigma2_mean_init" = sigma2_leaf_init, + "sample_sigma2_leaf_mean" = sample_sigma2_leaf, + "num_trees_variance" = num_trees_variance, + "leaf_prior_calibration_param" = a_0, + "shape_variance_forest" = a_forest, + "scale_variance_forest" = b_forest, + "alpha_variance" = alpha_variance, + "beta_variance" = beta_variance, + "min_samples_leaf_variance" = min_samples_leaf_variance, + "max_depth_variance" = max_depth_variance, + "leaf_constant_variance" = TRUE, + "leaf_dim_variance" = 1, + 
"exponentiated_leaf_variance" = TRUE, + "num_features_subsample_variance" = num_features_subsample_variance, + "feature_types" = feature_types, + "sweep_update_indices" = 0:(ncol(X_train) - 1), + "var_weights_mean" = variable_weights_mean, + "var_weights_variance" = variable_weights_variance ) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) - } - requires_basis <- TRUE - } else { - forest_dataset_train <- createForestDataset( - X_train, - variance_weights = observation_weights + + bart_results <- bart_sample_cpp( + X_train = X_train, + y_train = y_train, + X_test = if (exists("X_test")) X_test else NULL, + n_train = nrow(X_train), + n_test = if (!is.null(X_test)) nrow(X_test) else 0L, + p = ncol(X_train), + basis_train = if (exists("leaf_basis_train")) leaf_basis_train else NULL, + basis_test = if (exists("leaf_basis_test")) leaf_basis_test else NULL, + basis_dim = if (!is.null(leaf_basis_train)) ncol(leaf_basis_train) else 0L, + obs_weights_train = if (exists("obs_weights_train")) { + obs_weights_train + } else { + NULL + }, + obs_weights_test = if (exists("obs_weights_test")) { + obs_weights_test + } else { + NULL + }, + rfx_group_ids_train = if (exists("rfx_group_ids_train")) { + rfx_group_ids_train + } else { + NULL + }, + rfx_group_ids_test = if (exists("rfx_group_ids_test")) { + rfx_group_ids_test + } else { + NULL + }, + rfx_basis_train = if (exists("rfx_basis_train")) { + rfx_basis_train + } else { + NULL + }, + rfx_basis_test = if (exists("rfx_basis_test")) rfx_basis_test else NULL, + rfx_num_groups = if (exists("num_rfx_groups")) as.integer(num_rfx_groups) else 0L, + rfx_basis_dim = as.integer(num_basis_rfx), + num_gfr = as.integer(num_gfr), + num_burnin = as.integer(num_burnin), + keep_every = as.integer(keep_every), + num_mcmc = as.integer(num_mcmc), + config_input = bart_config ) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test) + result <- bart_results + class(result) <- "bartmodel" + } else 
{ + # Set a function-scoped RNG if user provided a random seed + custom_rng <- random_seed >= 0 + has_existing_random_seed <- F + if (custom_rng) { + # Cache original global environment RNG state (if it exists) + if (exists(".Random.seed", envir = .GlobalEnv)) { + original_global_seed <- .Random.seed + has_existing_random_seed <- T + } + # Set new seed and store associated RNG state + set.seed(random_seed) } - requires_basis <- FALSE - } - outcome_train <- createOutcome(resid_train) - - # Random number generator (std::mt19937) - if (is.null(random_seed)) { - random_seed = sample(1:10000, 1, FALSE) - } - rng <- createCppRNG(random_seed) - # Separate ordinal sampler object for cloglog - if (link_is_cloglog) { - ordinal_sampler <- ordinal_sampler_cpp() - } + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (link_is_probit) { + # Probit-scale intercept: center the forest on the population-average latent mean. + # The forest predicts mu(X) and y_bar_train is added back at prediction time. + # The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update. 
+ y_bar_train <- qnorm(mean_cpp(as.numeric(y_train))) + y_std_train <- 1 + standardize <- FALSE - # Sampling data structures - feature_types <- as.integer(feature_types) - global_model_config <- createGlobalModelConfig( - global_error_variance = current_sigma2 - ) - if (include_mean_forest) { - forest_model_config_mean <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_mean, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_mean, - leaf_dimension = leaf_dimension, - alpha = alpha_mean, - beta = beta_mean, - min_samples_leaf = min_samples_leaf_mean, - max_depth = max_depth_mean, - leaf_model_type = leaf_model_mean_forest, - leaf_model_scale = current_leaf_scale, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_mean - ) - if (link_is_cloglog) { - forest_model_config_mean$update_cloglog_forest_shape(cloglog_forest_shape) - forest_model_config_mean$update_cloglog_forest_rate(cloglog_forest_rate) - } - forest_model_mean <- createForestModel( - forest_dataset_train, - forest_model_config_mean, - global_model_config - ) - } - if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_variance, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_variance, - leaf_dimension = 1, - alpha = alpha_variance, - beta = beta_variance, - min_samples_leaf = min_samples_leaf_variance, - max_depth = max_depth_variance, - leaf_model_type = leaf_model_variance_forest, - variance_forest_shape = a_forest, - variance_forest_scale = b_forest, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_variance - ) - forest_model_variance <- createForestModel( - forest_dataset_train, - forest_model_config_variance, - global_model_config - ) - } + # Set a pseudo outcome by subtracting 
mean_cpp(y_train) from y_train + resid_train <- y_train - mean_cpp(as.numeric(y_train)) - # Container of forest samples - if (include_mean_forest) { - forest_samples_mean <- createForestSamples( - num_trees_mean, - leaf_dimension, - is_leaf_constant, - FALSE - ) - active_forest_mean <- createForest( - num_trees_mean, - leaf_dimension, - is_leaf_constant, - FALSE - ) - } - if (include_variance_forest) { - forest_samples_variance <- createForestSamples( - num_trees_variance, - 1, - TRUE, - TRUE - ) - active_forest_variance <- createForest( - num_trees_variance, - 1, - TRUE, - TRUE - ) - } + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean <- 0.0 - # Random effects initialization - if (has_rfx) { - # Prior parameters - if (is.null(rfx_working_parameter_prior_mean)) { - if (num_rfx_components == 1) { - alpha_init <- c(0) - } else if (num_rfx_components > 1) { - alpha_init <- rep(0, num_rfx_components) + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + if (is.null(b_leaf)) { + b_leaf <- 1 / (num_trees_mean) + } + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 / (num_trees_mean), + ncol(leaf_basis_train) + ) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } } else { - stop("There must be at least 1 random effect component") + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } + if 
(!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } } - } else { - alpha_init <- expand_dims_1d( - rfx_working_parameter_prior_mean, - num_rfx_components - ) - } - - if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix( - rep(alpha_init, num_rfx_groups), - num_rfx_components, - num_rfx_groups - ) - } else { - xi_init <- expand_dims_2d( - rfx_group_parameter_prior_mean, - num_rfx_components, - num_rfx_groups - ) - } - - if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_alpha_init <- expand_dims_2d_diag( - rfx_working_parameter_prior_cov, - num_rfx_components - ) - } + current_sigma2 <- sigma2_init + } else if (link_is_cloglog) { + # Fix offset to 0 and scale to 1 + y_bar_train <- 0 + y_std_train <- 1 + standardize <- FALSE - if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_xi_init <- expand_dims_2d_diag( - rfx_group_parameter_prior_cov, - num_rfx_components - ) - } + # Remap outcomes to start from 0 + resid_train <- as.numeric(y_train - min(unique_outcomes)) + cloglog_num_categories <- max(resid_train) + 1 - sigma_xi_shape <- rfx_variance_prior_shape - sigma_xi_scale <- rfx_variance_prior_scale + # Set initial values of root nodes to 0.0 (in linear scale) + init_val_mean <- 0.0 - # Random effects data structure and storage container - rfx_dataset_train <- createRandomEffectsDataset( - rfx_group_ids_train, - rfx_basis_train - ) - rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel( - num_rfx_components, - num_rfx_groups - ) - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - 
rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples( - num_rfx_components, - num_rfx_groups, - rfx_tracker_train - ) - } + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } + current_sigma2 <- sigma2_init + current_leaf_scale <- sigma2_leaf_init - # Container of parameter samples - num_actual_mcmc_iter <- num_mcmc * keep_every - num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter - # Delete GFR samples from these containers after the fact if desired - # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + - ifelse(keep_burnin, num_burnin, 0) + - num_mcmc * num_chains - if (sample_sigma2_global) { - global_var_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf) { - leaf_scale_samples <- rep(NA, num_retained_samples) - } - if (link_is_cloglog) { - cloglog_cutpoint_samples <- matrix( - NA_real_, - cloglog_num_categories - 1, - num_retained_samples - ) - } - if (include_mean_forest) { - mean_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) - } - if (include_variance_forest) { - variance_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) - } - sample_counter <- 0 + # Set first cutpoint to 0 for identifiability + cloglog_cutpoint_0 <- 0 - # Initialize the leaves of each tree in the mean forest - if (include_mean_forest) { - if (requires_basis) { - # Handle the case in which we must initialize root values in a leaf basis regression - # when init_val_mean != 0. 
To do this, we regress rep(init_val_mean, nrow(y_train)) - # on leaf_basis_train and use (coefs / num_trees_mean) as initial values - if (abs(init_val_mean) > 0.00001) { - init_val_y <- rep(init_val_mean, nrow(y_train)) - init_val_model <- lm(init_val_y ~ 0 + leaf_basis_train) - init_values_mean_forest <- coef(init_val_model) - if (any(is.na(init_values_mean_forest))) { - init_values_mean_forest[which(is.na(init_values_mean_forest))] <- 0. - } + # Set shape and rate parameters for conditional gamma model + cloglog_forest_shape <- 2.0 + cloglog_forest_rate <- 2.0 + } else { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean_cpp(as.numeric(y_train)) + y_std_train <- sd_cpp(as.numeric(y_train)) } else { - init_values_mean_forest <- rep(init_val_mean, ncol(leaf_basis_train)) + y_bar_train <- 0 + y_std_train <- 1 } - } else { - init_values_mean_forest <- init_val_mean - } - active_forest_mean$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_mean, - leaf_model_mean_forest, - init_values_mean_forest - ) - } - # Initialize the leaves of each tree in the variance forest - if (include_variance_forest) { - active_forest_variance$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_variance, - leaf_model_variance_forest, - variance_forest_init - ) - } + # Compute standardized outcome + resid_train <- (y_train - y_bar_train) / y_std_train - # Initialize auxiliary data for cloglog - if (link_is_cloglog) { - ## Allocate auxiliary data - train_size <- nrow(X_train) - # Latent variable (Z in Alam et al (2025) notation) - forest_dataset_train$add_auxiliary_dimension(train_size) - # Forest predictions (eta in Alam et al (2025) notation) - forest_dataset_train$add_auxiliary_dimension(train_size) - # Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation) - forest_dataset_train$add_auxiliary_dimension(cloglog_num_categories - 1) - # Exponentiated cumulative cutpoints (exp(c_k) in Alam et 
al (2025) notation) - # This auxiliary series is designed so that the element stored at position `i` - # corresponds to the sum of all exponentiated gamma_j values for j < i. - # It has cloglog_num_categories elements instead of cloglog_num_categories - 1 because - # even the largest categorical index has a valid value of sum_{j < i} exp(gamma_j) - forest_dataset_train$add_auxiliary_dimension(cloglog_num_categories) - - ## Set initial values for auxiliary data - # Initialize latent variables to zero (slot 0) - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) - } - # Initialize forest predictions to zero (slot 1) - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value(1, i - 1, 0.0) - } - # Initialize log-scale cutpoints to 0 - initial_gamma <- rep(0.0, cloglog_num_categories - 1) - for (i in seq_along(initial_gamma)) { - forest_dataset_train$set_auxiliary_data_value(2, i - 1, initial_gamma[i]) - } - # Convert to cumulative exponentiated cutpoints directly in C++ - ordinal_sampler_update_cumsum_exp_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr - ) - } + # Compute initial value of root nodes in mean forest + init_val_mean <- mean_cpp(as.numeric(resid_train)) - # Run GFR (warm start) if specified - if (num_gfr > 0) { - for (i in 1:num_gfr) { - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) - keep_sample <- TRUE - if (keep_sample) { - sample_counter <- sample_counter + 1 + # Calibrate priors for sigma^2 and tau + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var_cpp(as.numeric(resid_train)) } - # Print progress - if (verbose) { - if ((i %% 10 == 0) || (i == num_gfr)) { - cat( - "Sampling", - i, - "out of", - num_gfr, - "XBART (grow-from-root) draws\n" - ) - } + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var_cpp(as.numeric(resid_train)) } - - if (include_mean_forest) { - if (link_is_probit) { 
- # Sample latent probit variable, z | - - # outcome_pred is the centered forest prediction (not including y_bar_train). - # The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale). - # The residual stored is z - y_bar_train - outcome_pred so the forest sees a - # zero-centered signal and the prior shrinkage toward 0 is well-calibrated. - outcome_pred <- active_forest_mean$predict( - forest_dataset_train - ) - if (has_rfx) { - rfx_pred <- rfx_model$predict( - rfx_dataset_train, - rfx_tracker_train + if (is.null(b_leaf)) { + b_leaf <- var_cpp(as.numeric(resid_train)) / (2 * num_trees_mean) + } + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean), + ncol(leaf_basis_train) ) - outcome_pred <- outcome_pred + rfx_pred } - eta_pred <- outcome_pred + y_bar_train - mu0 <- eta_pred[y_train == 0] - mu1 <- eta_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome: center z by y_bar_train before passing to forest - outcome_train$update_data(resid_train - y_bar_train - outcome_pred) - } - - # Sample mean forest - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() - } - - # Additional Gibbs updates needed for the cloglog model - 
if (link_is_cloglog) { - # Update auxiliary data to current forest predictions - forest_pred_current <- forest_model_mean$get_cached_forest_predictions() - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value( - 1, - i - 1, - forest_pred_current[i] - ) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) + } else { + current_leaf_scale <- sigma2_leaf_init } - - # Sample latent z_i's using truncated exponential - ordinal_sampler_update_latent_variables_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr, - outcome_train$data_ptr, - rng$rng_ptr - ) - - # Sample gamma parameters (cutpoints) - ordinal_sampler_update_gamma_params_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr, - outcome_train$data_ptr, - cloglog_forest_shape, - cloglog_forest_rate, - cloglog_cutpoint_0, - rng$rng_ptr - ) - - # Update cumulative sum of exp(gamma) values - ordinal_sampler_update_cumsum_exp_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr - ) - - # Retain cutpoint draw - if (keep_sample) { - cloglog_cutpoints <- forest_dataset_train$get_auxiliary_data_vector( - 2 + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean) ) - cloglog_cutpoint_samples[, sample_counter] <- cloglog_cutpoints + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init } } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since 
they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean) + ) } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init } - global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double - } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) + current_sigma2 <- sigma2_init + } + + # Data + if (leaf_regression) { + forest_dataset_train <- createForestDataset( + X_train, + leaf_basis_train, + observation_weights + ) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) + requires_basis <- TRUE + } else { + forest_dataset_train <- createForestDataset( + X_train, + variance_weights = observation_weights + ) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test) } + requires_basis <- FALSE } - } + outcome_train <- createOutcome(resid_train) - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (chain_num in 1:num_chains) { - if (verbose) { - cat("Sampling chain", chain_num, "of", 
num_chains, "\n") - } - if (num_gfr > 0) { - # Reset state of active_forest and forest_model based on a previous GFR sample - forest_ind <- num_gfr - chain_num - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - forest_samples_mean, - forest_ind - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf) { - leaf_scale_double <- leaf_scale_samples[forest_ind + 1] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale + # Random number generator (std::mt19937) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } + rng <- createCppRNG(random_seed) + + # Separate ordinal sampler object for cloglog + if (link_is_cloglog) { + ordinal_sampler <- ordinal_sampler_cpp() + } + + # Sampling data structures + feature_types <- as.integer(feature_types) + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + if (include_mean_forest) { + forest_model_config_mean <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mean, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mean, + leaf_dimension = leaf_dimension, + alpha = alpha_mean, + beta = beta_mean, + min_samples_leaf = min_samples_leaf_mean, + max_depth = max_depth_mean, + leaf_model_type = leaf_model_mean_forest, + leaf_model_scale = current_leaf_scale, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mean + ) + if (link_is_cloglog) { + forest_model_config_mean$update_cloglog_forest_shape( + cloglog_forest_shape + ) + forest_model_config_mean$update_cloglog_forest_rate(cloglog_forest_rate) + } + forest_model_mean <- createForestModel( + forest_dataset_train, + forest_model_config_mean, + global_model_config + ) + } + if (include_variance_forest) { + 
forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = 1, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + variance_forest_shape = a_forest, + variance_forest_scale = b_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + + # Container of forest samples + if (include_mean_forest) { + forest_samples_mean <- createForestSamples( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + active_forest_mean <- createForest( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + } + if (include_variance_forest) { + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) + } + + # Random effects initialization + if (has_rfx) { + # Prior parameters + if (is.null(rfx_working_parameter_prior_mean)) { + if (num_rfx_components == 1) { + alpha_init <- c(0) + } else if (num_rfx_components > 1) { + alpha_init <- rep(0, num_rfx_components) + } else { + stop("There must be at least 1 random effect component") + } + } else { + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_mean)) { + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) + } else { + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) + } + + if 
(is.null(rfx_working_parameter_prior_cov)) { + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_cov)) { + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) + } + + sigma_xi_shape <- rfx_variance_prior_shape + sigma_xi_scale <- rfx_variance_prior_scale + + # Random effects data structure and storage container + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) + rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) + } + + # Container of parameter samples + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf) { + leaf_scale_samples <- rep(NA, num_retained_samples) + } + if (link_is_cloglog) { + cloglog_cutpoint_samples <- matrix( + NA_real_, + cloglog_num_categories - 1, + num_retained_samples 
+ ) + } + if (include_mean_forest) { + mean_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + if (include_variance_forest) { + variance_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + sample_counter <- 0 + + # Initialize the leaves of each tree in the mean forest + if (include_mean_forest) { + if (requires_basis) { + # Handle the case in which we must initialize root values in a leaf basis regression + # when init_val_mean != 0. To do this, we regress rep(init_val_mean, nrow(y_train)) + # on leaf_basis_train and use (coefs / num_trees_mean) as initial values + if (abs(init_val_mean) > 0.00001) { + init_val_y <- rep(init_val_mean, nrow(y_train)) + init_val_model <- lm(init_val_y ~ 0 + leaf_basis_train) + init_values_mean_forest <- coef(init_val_model) + if (any(is.na(init_values_mean_forest))) { + init_values_mean_forest[which(is.na(init_values_mean_forest))] <- 0. + } + } else { + init_values_mean_forest <- rep(init_val_mean, ncol(leaf_basis_train)) + } + } else { + init_values_mean_forest <- init_val_mean + } + active_forest_mean$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mean, + leaf_model_mean_forest, + init_values_mean_forest + ) + } + + # Initialize the leaves of each tree in the variance forest + if (include_variance_forest) { + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) + } + + # Initialize auxiliary data for cloglog + if (link_is_cloglog) { + ## Allocate auxiliary data + train_size <- nrow(X_train) + # Latent variable (Z in Alam et al (2025) notation) + forest_dataset_train$add_auxiliary_dimension(train_size) + # Forest predictions (eta in Alam et al (2025) notation) + forest_dataset_train$add_auxiliary_dimension(train_size) + # Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation) + 
forest_dataset_train$add_auxiliary_dimension(cloglog_num_categories - 1) + # Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation) + # This auxiliary series is designed so that the element stored at position `i` + # corresponds to the sum of all exponentiated gamma_j values for j < i. + # It has cloglog_num_categories elements instead of cloglog_num_categories - 1 because + # even the largest categorical index has a valid value of sum_{j < i} exp(gamma_j) + forest_dataset_train$add_auxiliary_dimension(cloglog_num_categories) + + ## Set initial values for auxiliary data + # Initialize latent variables to zero (slot 0) + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) + } + # Initialize forest predictions to zero (slot 1) + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value(1, i - 1, 0.0) + } + # Initialize log-scale cutpoints to 0 + initial_gamma <- rep(0.0, cloglog_num_categories - 1) + for (i in seq_along(initial_gamma)) { + forest_dataset_train$set_auxiliary_data_value( + 2, + i - 1, + initial_gamma[i] + ) + } + # Convert to cumulative exponentiated cutpoints directly in C++ + ordinal_sampler_update_cumsum_exp_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr + ) + } + + # Run GFR (warm start) if specified + if (num_gfr > 0) { + for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + keep_sample <- TRUE + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if ((i %% 10 == 0) || (i == num_gfr)) { + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBART (grow-from-root) draws\n" ) } + } + + if (include_mean_forest) { + if (link_is_probit) { + # Sample latent probit variable, z | - + # outcome_pred is the centered forest prediction (not including y_bar_train). 
+ # The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale). + # The residual stored is z - y_bar_train - outcome_pred so the forest sees a + # zero-centered signal and the prior shrinkage toward 0 is well-calibrated. + outcome_pred <- active_forest_mean$predict( + forest_dataset_train + ) + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome: center z by y_bar_train before passing to forest + outcome_train$update_data(resid_train - y_bar_train - outcome_pred) + } + + # Sample mean forest + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() + } + + # Additional Gibbs updates needed for the cloglog model if (link_is_cloglog) { - # Restore ordinal labels corrupted by resetForestModel's - # residual adjustment (outcome stores category labels, not residuals) - outcome_train$update_data(resid_train) - # We can reset cutpoints from warm-start since cutpoints are retained - current_cutpoints <- cloglog_cutpoint_samples[, forest_ind + 1] - for (i in seq_along(current_cutpoints)) { + # Update auxiliary data to current 
forest predictions + forest_pred_current <- forest_model_mean$get_cached_forest_predictions() + for (i in 1:train_size) { forest_dataset_train$set_auxiliary_data_value( - 2, + 1, i - 1, - current_cutpoints[i] + forest_pred_current[i] ) } + + # Sample latent z_i's using truncated exponential + ordinal_sampler_update_latent_variables_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr, + outcome_train$data_ptr, + rng$rng_ptr + ) + + # Sample gamma parameters (cutpoints) + ordinal_sampler_update_gamma_params_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr, + outcome_train$data_ptr, + cloglog_forest_shape, + cloglog_forest_rate, + cloglog_cutpoint_0, + rng$rng_ptr + ) + + # Update cumulative sum of exp(gamma) values ordinal_sampler_update_cumsum_exp_cpp( ordinal_sampler, forest_dataset_train$data_ptr ) - # Re-predict from the reconstituted active forest - active_forest_preds <- active_forest_mean$predict( - forest_dataset_train - ) - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value( - 1, - i - 1, - active_forest_preds[i] + + # Retain cutpoint draw + if (keep_sample) { + cloglog_cutpoints <- forest_dataset_train$get_auxiliary_data_vector( + 2 ) - # Latent variables must be reset to 0 and burnt in - forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) + cloglog_cutpoint_samples[, sample_counter] <- cloglog_cutpoints } } } if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - forest_samples_variance, - forest_ind + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, + + # Cache train 
set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome_train, - FALSE + forest_dataset_train, + rng, + a_global, + b_global ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance(current_sigma2) } - if (has_rfx) { - resetRandomEffectsModel( - rfx_model, - rfx_samples, - forest_ind, - sigma_alpha_init + if (sample_sigma2_leaf) { + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + if (has_rfx) { + rfx_model$sample_random_effect( rfx_dataset_train, outcome_train, - rfx_samples + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng ) } - if (sample_sigma2_global) { - current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance( - current_sigma2 - ) + } + } + + # Run MCMC + if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + if (verbose) { + cat("Sampling chain", chain_num, "of", num_chains, "\n") } - } else if (has_prev_model) { - warmstart_index <- ifelse( - previous_model_decrement, - previous_model_warmstart_sample_num - chain_num + 1, - previous_model_warmstart_sample_num - ) - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - previous_forest_samples_mean, - warmstart_index - 1 - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if ( - 
sample_sigma2_leaf && - (!is.null(previous_leaf_var_samples)) - ) { - leaf_scale_double <- previous_leaf_var_samples[ - warmstart_index - ] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + if (include_mean_forest) { + resetActiveForest( + active_forest_mean, + forest_samples_mean, + forest_ind ) - } - if (link_is_cloglog) { - # Restore ordinal labels corrupted by resetForestModel's - # residual adjustment (outcome stores category labels, not residuals) - outcome_train$update_data(resid_train) - # We can reset cutpoints from warm-start since cutpoints are retained - current_cutpoints <- previous_cloglog_cutpoint_samples[, - warmstart_index - ] - for (i in seq_along(current_cutpoints)) { - forest_dataset_train$set_auxiliary_data_value( - 2, - i - 1, - current_cutpoints[i] + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if (sample_sigma2_leaf) { + leaf_scale_double <- leaf_scale_samples[forest_ind + 1] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) } - ordinal_sampler_update_cumsum_exp_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr - ) - # Re-predict from the reconstituted active forest - active_forest_preds <- active_forest_mean$predict( - forest_dataset_train - ) - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value( - 1, - i - 1, - active_forest_preds[i] + if (link_is_cloglog) { + # Restore ordinal labels corrupted by resetForestModel's + # residual adjustment (outcome stores category labels, not residuals) + outcome_train$update_data(resid_train) + # We can reset cutpoints from warm-start since cutpoints are retained + current_cutpoints <- cloglog_cutpoint_samples[, 
forest_ind + 1] + for (i in seq_along(current_cutpoints)) { + forest_dataset_train$set_auxiliary_data_value( + 2, + i - 1, + current_cutpoints[i] + ) + } + ordinal_sampler_update_cumsum_exp_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr + ) + # Re-predict from the reconstituted active forest + active_forest_preds <- active_forest_mean$predict( + forest_dataset_train ) - # Latent variables must be reset to 0 and burnt in - forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value( + 1, + i - 1, + active_forest_preds[i] + ) + # Latent variables must be reset to 0 and burnt in + forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) + } } } - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - previous_forest_samples_variance, - warmstart_index - 1 - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - if (is.null(previous_rfx_samples)) { - warning( - "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind ) - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE ) - } else { + } + if (has_rfx) { resetRandomEffectsModel( rfx_model, - previous_rfx_samples, - warmstart_index - 1, + rfx_samples, + forest_ind, sigma_alpha_init ) resetRandomEffectsTracker( @@ -1959,552 +1947,673 @@ bart <- function( rfx_samples ) } - } 
- if (sample_sigma2_global) { - if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[ - warmstart_index - ] + if (sample_sigma2_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] global_model_config$update_global_error_variance( current_sigma2 ) } - } - } else { - if (include_mean_forest) { - resetActiveForest(active_forest_mean) - active_forest_mean$set_root_leaves( - init_values_mean_forest / num_trees_mean - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE + } else if (has_prev_model) { + warmstart_index <- ifelse( + previous_model_decrement, + previous_model_warmstart_sample_num - chain_num + 1, + previous_model_warmstart_sample_num ) - if (sample_sigma2_leaf) { - current_leaf_scale <- as.matrix(sigma2_leaf_init) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale + if (include_mean_forest) { + resetActiveForest( + active_forest_mean, + previous_forest_samples_mean, + warmstart_index - 1 ) - } - if (link_is_cloglog) { - # Restore ordinal labels corrupted by resetForestModel's - # residual adjustment (outcome stores category labels, not residuals) - outcome_train$update_data(resid_train) - # Reset all cloglog parameters to default values - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) - forest_dataset_train$set_auxiliary_data_value(1, i - 1, 0.0) - } - # Initialize log-scale cutpoints to 0 - initial_gamma <- rep(0.0, cloglog_num_categories - 1) - for (i in seq_along(initial_gamma)) { - forest_dataset_train$set_auxiliary_data_value( - 2, - i - 1, - initial_gamma[i] - ) - } - # Convert to cumulative exponentiated cutpoints directly in C++ - ordinal_sampler_update_cumsum_exp_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE ) - } - } - if (include_variance_forest) { - 
resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves( - log(variance_forest_init) / num_trees_variance - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } - if (sample_sigma2_global) { - current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - for (i in (num_gfr + 1):num_samples) { - is_mcmc <- i > (num_gfr + num_burnin) - if (is_mcmc) { - mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } else { - if (keep_burnin) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if (num_burnin > 0 && !is_mcmc) { if ( - ((i - num_gfr) %% 100 == 0) || - ((i - num_gfr) == num_burnin) + sample_sigma2_leaf && + (!is.null(previous_leaf_var_samples)) ) { - cat( - "Sampling", - i - num_gfr, - "out of", - num_burnin, - "BART burn-in draws; Chain number ", - chain_num, - "\n" + leaf_scale_double <- previous_leaf_var_samples[ + warmstart_index + ] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) } - } - if (num_mcmc > 0 && is_mcmc) { - raw_iter <- i - num_gfr - num_burnin - if ((raw_iter %% 100 == 0) || (i == num_samples)) { - if (keep_every == 1) { - cat( - "Sampling", - raw_iter, - "out of", - num_mcmc, - "BART MCMC draws; Chain number ", - chain_num, - "\n" + if (link_is_cloglog) { + # Restore ordinal labels corrupted by resetForestModel's + # residual adjustment 
(outcome stores category labels, not residuals) + outcome_train$update_data(resid_train) + # We can reset cutpoints from warm-start since cutpoints are retained + current_cutpoints <- previous_cloglog_cutpoint_samples[, + warmstart_index + ] + for (i in seq_along(current_cutpoints)) { + forest_dataset_train$set_auxiliary_data_value( + 2, + i - 1, + current_cutpoints[i] ) - } else { - cat( - "Sampling raw draw", - raw_iter, - "of", - num_actual_mcmc_iter, - "BART MCMC draws (thinning by", - keep_every, - ":", - raw_iter %/% keep_every, - "of", - num_mcmc, - "retained); Chain number ", - chain_num, - "\n" + } + ordinal_sampler_update_cumsum_exp_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr + ) + # Re-predict from the reconstituted active forest + active_forest_preds <- active_forest_mean$predict( + forest_dataset_train + ) + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value( + 1, + i - 1, + active_forest_preds[i] ) + # Latent variables must be reset to 0 and burnt in + forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) } } } - } - - if (include_mean_forest) { - if (link_is_probit) { - # Sample latent probit variable, z | - - outcome_pred <- active_forest_mean$predict( - forest_dataset_train + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + warmstart_index - 1 ) - if (has_rfx) { - rfx_pred <- rfx_model$predict( + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + 
rfx_tracker_train, + rfx_model, rfx_dataset_train, - rfx_tracker_train + outcome_train + ) + } else { + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + warmstart_index - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples ) - outcome_pred <- outcome_pred + rfx_pred } - eta_pred <- outcome_pred + y_bar_train - mu0 <- eta_pred[y_train == 0] - mu1 <- eta_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome: center z by y_bar_train before passing to forest - outcome_train$update_data( - resid_train - y_bar_train - outcome_pred + } + if (sample_sigma2_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[ + warmstart_index + ] + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + } else { + if (include_mean_forest) { + resetActiveForest(active_forest_mean) + active_forest_mean$set_root_leaves( + init_values_mean_forest / num_trees_mean + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE ) + if (sample_sigma2_leaf) { + current_leaf_scale <- as.matrix(sigma2_leaf_init) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + if (link_is_cloglog) { + # Restore ordinal labels corrupted by resetForestModel's + # residual adjustment (outcome stores category labels, not residuals) + outcome_train$update_data(resid_train) + # Reset all cloglog parameters to default values + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value(0, i - 1, 0.0) + forest_dataset_train$set_auxiliary_data_value(1, i - 1, 0.0) + } + # Initialize log-scale cutpoints to 0 + initial_gamma <- rep(0.0, cloglog_num_categories - 1) + for 
(i in seq_along(initial_gamma)) { + forest_dataset_train$set_auxiliary_data_value( + 2, + i - 1, + initial_gamma[i] + ) + } + # Convert to cumulative exponentiated cutpoints directly in C++ + ordinal_sampler_update_cumsum_exp_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr + ) + } + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) + } + if (sample_sigma2_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + for (i in (num_gfr + 1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } + } else { + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } - - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() + sample_counter <- sample_counter + 1 + } + # Print progress + if 
(verbose) { + if (num_burnin > 0 && !is_mcmc) { + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_burnin, + "BART burn-in draws; Chain number ", + chain_num, + "\n" + ) + } + } + if (num_mcmc > 0 && is_mcmc) { + raw_iter <- i - num_gfr - num_burnin + if ((raw_iter %% 100 == 0) || (i == num_samples)) { + if (keep_every == 1) { + cat( + "Sampling", + raw_iter, + "out of", + num_mcmc, + "BART MCMC draws; Chain number ", + chain_num, + "\n" + ) + } else { + cat( + "Sampling raw draw", + raw_iter, + "of", + num_actual_mcmc_iter, + "BART MCMC draws (thinning by", + keep_every, + ":", + raw_iter %/% keep_every, + "of", + num_mcmc, + "retained); Chain number ", + chain_num, + "\n" + ) + } + } + } } - # Additional Gibbs updates needed for the cloglog model - if (link_is_cloglog) { - # Update auxiliary data to current forest predictions - forest_pred_current <- forest_model_mean$get_cached_forest_predictions() - for (i in 1:train_size) { - forest_dataset_train$set_auxiliary_data_value( - 1, - i - 1, - forest_pred_current[i] + if (include_mean_forest) { + if (link_is_probit) { + # Sample latent probit variable, z | - + outcome_pred <- active_forest_mean$predict( + forest_dataset_train + ) + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome: center z by y_bar_train before passing to forest + outcome_train$update_data( + resid_train - y_bar_train - outcome_pred ) } - # Sample latent z_i's using truncated exponential - ordinal_sampler_update_latent_variables_cpp( - ordinal_sampler, - 
forest_dataset_train$data_ptr, - outcome_train$data_ptr, - rng$rng_ptr + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE ) - # Sample gamma parameters (cutpoints) - ordinal_sampler_update_gamma_params_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr, - outcome_train$data_ptr, - cloglog_forest_shape, - cloglog_forest_rate, - cloglog_cutpoint_0, - rng$rng_ptr - ) + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() + } - # Update cumulative sum of exp(gamma) values - ordinal_sampler_update_cumsum_exp_cpp( - ordinal_sampler, - forest_dataset_train$data_ptr - ) + # Additional Gibbs updates needed for the cloglog model + if (link_is_cloglog) { + # Update auxiliary data to current forest predictions + forest_pred_current <- forest_model_mean$get_cached_forest_predictions() + for (i in 1:train_size) { + forest_dataset_train$set_auxiliary_data_value( + 1, + i - 1, + forest_pred_current[i] + ) + } - # Retain cutpoint draw - if (keep_sample) { - cloglog_cutpoints <- forest_dataset_train$get_auxiliary_data_vector( - 2 + # Sample latent z_i's using truncated exponential + ordinal_sampler_update_latent_variables_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr, + outcome_train$data_ptr, + rng$rng_ptr ) - cloglog_cutpoint_samples[, sample_counter] <- cloglog_cutpoints + + # Sample gamma parameters (cutpoints) + ordinal_sampler_update_gamma_params_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr, + outcome_train$data_ptr, + cloglog_forest_shape, + cloglog_forest_rate, + cloglog_cutpoint_0, + 
rng$rng_ptr + ) + + # Update cumulative sum of exp(gamma) values + ordinal_sampler_update_cumsum_exp_cpp( + ordinal_sampler, + forest_dataset_train$data_ptr + ) + + # Retain cutpoint draw + if (keep_sample) { + cloglog_cutpoints <- forest_dataset_train$get_auxiliary_data_vector( + 2 + ) + cloglog_cutpoint_samples[, sample_counter] <- cloglog_cutpoints + } } } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + 
forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double + if (sample_sigma2_leaf) { + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) } } } - } - # Remove GFR samples if they are not to be retained - if ((!keep_gfr) && (num_gfr > 0)) { - for (i in 1:num_gfr) { + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { + if (include_mean_forest) { + forest_samples_mean$delete_sample(0) + } + if (include_variance_forest) { + forest_samples_variance$delete_sample(0) + } + if (has_rfx) { + rfx_samples$delete_sample(0) + } + } if (include_mean_forest) { - forest_samples_mean$delete_sample(0) + mean_forest_pred_train <- mean_forest_pred_train[, + (num_gfr + 1):ncol(mean_forest_pred_train) + ] + if (link_is_cloglog) { + cloglog_cutpoint_samples 
<- cloglog_cutpoint_samples[, + (num_gfr + 1):ncol(cloglog_cutpoint_samples), + drop = FALSE + ] + } } if (include_variance_forest) { - forest_samples_variance$delete_sample(0) + variance_forest_pred_train <- variance_forest_pred_train[, + (num_gfr + 1):ncol(variance_forest_pred_train) + ] } - if (has_rfx) { - rfx_samples$delete_sample(0) + if (sample_sigma2_global) { + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] + } + if (sample_sigma2_leaf) { + leaf_scale_samples <- leaf_scale_samples[ + (num_gfr + 1):length(leaf_scale_samples) + ] } + num_retained_samples <- num_retained_samples - num_gfr } + + # Mean forest predictions if (include_mean_forest) { - mean_forest_pred_train <- mean_forest_pred_train[, - (num_gfr + 1):ncol(mean_forest_pred_train) - ] - if (link_is_cloglog) { - cloglog_cutpoint_samples <- cloglog_cutpoint_samples[, - (num_gfr + 1):ncol(cloglog_cutpoint_samples), - drop = FALSE - ] + # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train + if (has_test) { + y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * + y_std_train + + y_bar_train } } + + # Variance forest predictions if (include_variance_forest) { - variance_forest_pred_train <- variance_forest_pred_train[, - (num_gfr + 1):ncol(variance_forest_pred_train) - ] + # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + sigma2_x_hat_train <- exp(variance_forest_pred_train) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } } - if (sample_sigma2_global) { - global_var_samples <- global_var_samples[ - (num_gfr + 1):length(global_var_samples) - ] + + # Random effects predictions + if (has_rfx) { + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train + y_hat_train <- y_hat_train + rfx_preds_train } - if 
(sample_sigma2_leaf) { - leaf_scale_samples <- leaf_scale_samples[ - (num_gfr + 1):length(leaf_scale_samples) - ] + if ((has_rfx_test) && (has_test)) { + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train + y_hat_test <- y_hat_test + rfx_preds_test } - num_retained_samples <- num_retained_samples - num_gfr - } - # Mean forest predictions - if (include_mean_forest) { - # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train - y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train - if (has_test) { - y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * - y_std_train + - y_bar_train + # Global error variance + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) } - } - # Variance forest predictions - if (include_variance_forest) { - # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) - sigma2_x_hat_train <- exp(variance_forest_pred_train) - if (has_test) { - sigma2_x_hat_test <- forest_samples_variance$predict( - forest_dataset_test - ) + # Leaf parameter variance + if (sample_sigma2_leaf) { + tau_samples <- leaf_scale_samples } - } - - # Random effects predictions - if (has_rfx) { - rfx_preds_train <- rfx_samples$predict( - rfx_group_ids_train, - rfx_basis_train - ) * - y_std_train - y_hat_train <- y_hat_train + rfx_preds_train - } - if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict( - rfx_group_ids_test, - rfx_basis_test - ) * - y_std_train - y_hat_test <- y_hat_test + rfx_preds_test - } - - # Global error variance - if (sample_sigma2_global) { - sigma2_global_samples <- global_var_samples * (y_std_train^2) - } - - # Leaf parameter variance - if (sample_sigma2_leaf) { - tau_samples <- leaf_scale_samples - } - # Rescale variance forest prediction by global sigma2 (sampled or constant) - if (include_variance_forest) { - if (sample_sigma2_global) { - sigma2_x_hat_train <- 
sapply(1:num_retained_samples, function(i) { - sigma2_x_hat_train[, i] * sigma2_global_samples[i] - }) - if (has_test) { - sigma2_x_hat_test <- sapply( - 1:num_retained_samples, - function(i) { - sigma2_x_hat_test[, i] * sigma2_global_samples[i] - } - ) - } - } else { - sigma2_x_hat_train <- sigma2_x_hat_train * - sigma2_init * - y_std_train * - y_std_train - if (has_test) { - sigma2_x_hat_test <- sigma2_x_hat_test * + # Rescale variance forest prediction by global sigma2 (sampled or constant) + if (include_variance_forest) { + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } + } else { + sigma2_x_hat_train <- sigma2_x_hat_train * sigma2_init * y_std_train * y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } } } - } - # Return results as a list - model_params <- list( - "sigma2_init" = sigma2_init, - "sigma2_leaf_init" = sigma2_leaf_init, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf" = a_leaf, - "b_leaf" = b_leaf, - "a_forest" = a_forest, - "b_forest" = b_forest, - "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, - "leaf_dimension" = leaf_dimension, - "is_leaf_constant" = is_leaf_constant, - "leaf_regression" = leaf_regression, - "requires_basis" = requires_basis, - "num_covariates" = num_cov_orig, - "num_basis" = ifelse( - is.null(leaf_basis_train), - 0, - ncol(leaf_basis_train) - ), - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_basis" = !is.null(leaf_basis_train), - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = 
num_basis_rfx, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf" = sample_sigma2_leaf, - "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest, - "outcome_model" = outcome_model, - "probit_outcome_model" = probit_outcome_model, - "cloglog_num_categories" = ifelse( - link_is_cloglog, - cloglog_num_categories, - 0 - ), - "rfx_model_spec" = rfx_model_spec - ) - result <- list( - "model_params" = model_params, - "train_set_metadata" = X_train_metadata - ) - if (include_mean_forest) { - result[["mean_forests"]] = forest_samples_mean - result[["y_hat_train"]] = y_hat_train - if (has_test) { - result[["y_hat_test"]] = y_hat_test + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "sigma2_leaf_init" = sigma2_leaf_init, + "a_global" = a_global, + "b_global" = b_global, + "a_leaf" = a_leaf, + "b_leaf" = b_leaf, + "a_forest" = a_forest, + "b_forest" = b_forest, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "standardize" = standardize, + "leaf_dimension" = leaf_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = requires_basis, + "num_covariates" = num_cov_orig, + "num_basis" = ifelse( + is.null(leaf_basis_train), + 0, + ncol(leaf_basis_train) + ), + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_basis" = !is.null(leaf_basis_train), + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf" = sample_sigma2_leaf, + "include_mean_forest" = include_mean_forest, + "include_variance_forest" = include_variance_forest, + "outcome_model" = outcome_model, + "probit_outcome_model" = probit_outcome_model, + "cloglog_num_categories" = ifelse( + link_is_cloglog, + 
cloglog_num_categories, + 0 + ), + "rfx_model_spec" = rfx_model_spec + ) + result <- list( + "model_params" = model_params, + "train_set_metadata" = X_train_metadata + ) + if (include_mean_forest) { + result[["mean_forests"]] = forest_samples_mean + result[["y_hat_train"]] = y_hat_train + if (has_test) { + result[["y_hat_test"]] = y_hat_test + } + if (link_is_cloglog && !outcome_is_binary) { + result[["cloglog_cutpoint_samples"]] = cloglog_cutpoint_samples + } } - if (link_is_cloglog && !outcome_is_binary) { - result[["cloglog_cutpoint_samples"]] = cloglog_cutpoint_samples + if (include_variance_forest) { + result[["variance_forests"]] = forest_samples_variance + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test } - } - if (include_variance_forest) { - result[["variance_forests"]] = forest_samples_variance - result[["sigma2_x_hat_train"]] = sigma2_x_hat_train - if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test - } - if (sample_sigma2_global) { - result[["sigma2_global_samples"]] = sigma2_global_samples - } - if (sample_sigma2_leaf) { - result[["sigma2_leaf_samples"]] = tau_samples - } - if (has_rfx) { - result[["rfx_samples"]] = rfx_samples - result[["rfx_preds_train"]] = rfx_preds_train - result[["rfx_unique_group_ids"]] = levels(group_ids_factor) - } - if ((has_rfx_test) && (has_test)) { - result[["rfx_preds_test"]] = rfx_preds_test - } - class(result) <- "bartmodel" + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf) { + result[["sigma2_leaf_samples"]] = tau_samples + } + if (has_rfx) { + result[["rfx_samples"]] = rfx_samples + result[["rfx_preds_train"]] = rfx_preds_train + result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + } + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } + class(result) <- "bartmodel" - # Clean up classes with external pointers to C++ data 
structures - if (include_mean_forest) { - rm(forest_model_mean) - } - if (include_variance_forest) { - rm(forest_model_variance) - } - rm(forest_dataset_train) - if (has_test) { - rm(forest_dataset_test) - } - if (has_rfx) { - rm(rfx_dataset_train, rfx_tracker_train, rfx_model) - } - rm(outcome_train) - rm(rng) + # Clean up classes with external pointers to C++ data structures + if (include_mean_forest) { + rm(forest_model_mean) + } + if (include_variance_forest) { + rm(forest_model_variance) + } + rm(forest_dataset_train) + if (has_test) { + rm(forest_dataset_test) + } + if (has_rfx) { + rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + } + rm(outcome_train) + rm(rng) - # Restore global RNG state if user provided a random seed - if (custom_rng) { - if (has_existing_random_seed) { - .Random.seed <- original_global_seed - } else { - rm(".Random.seed", envir = .GlobalEnv) + # Restore global RNG state if user provided a random seed + if (custom_rng) { + if (has_existing_random_seed) { + .Random.seed <- original_global_seed + } else { + rm(".Random.seed", envir = .GlobalEnv) + } } } diff --git a/R/cpp11.R b/R/cpp11.R index c24913f0..d2707bf4 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,5 +1,9 @@ # Generated by cpp11: do not edit by hand +bart_sample_cpp <- function(X_train, y_train, X_test, n_train, n_test, p, basis_train, basis_test, basis_dim, obs_weights_train, obs_weights_test, rfx_group_ids_train, rfx_group_ids_test, rfx_basis_train, rfx_basis_test, rfx_num_groups, rfx_basis_dim, num_gfr, num_burnin, keep_every, num_mcmc, config_input) { + .Call(`_stochtree_bart_sample_cpp`, X_train, y_train, X_test, n_train, n_test, p, basis_train, basis_test, basis_dim, obs_weights_train, obs_weights_test, rfx_group_ids_train, rfx_group_ids_test, rfx_basis_train, rfx_basis_test, rfx_num_groups, rfx_basis_dim, num_gfr, num_burnin, keep_every, num_mcmc, config_input) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } diff --git 
a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 6772aec1..b9d25958 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -163,8 +163,8 @@ static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, StochTree::BARTSamples samples; StochTree::BARTSampler sampler(samples, config, data); - sampler.run_gfr(samples, config, data, rng, num_gfr, true); - sampler.run_mcmc(samples, config, data, rng, 0, 1, num_mcmc); + sampler.run_gfr(samples, config, data, num_gfr, true); + sampler.run_mcmc(samples, config, data, 0, 1, num_mcmc); // Predictions are on latent scale (= raw + y_bar); compare to true latent Z. report_bart(samples, test.Z, "Scenario 1 (Probit BART)"); } diff --git a/debug/benchmark_cpp_vs_r_sampler.R b/debug/benchmark_cpp_vs_r_sampler.R new file mode 100644 index 00000000..848ae97a --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler.R @@ -0,0 +1,151 @@ +## Benchmark: C++ sampler loop vs. R sampler loop +## Compares runtime and test-set RMSE across run_cpp = TRUE / FALSE in bart(). 
+## +## Usage: Rscript debug/benchmark_cpp_vs_r_sampler.R +## or source() from an interactive session after devtools::load_all('.') +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) +f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-7.5) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-2.5) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (2.5) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (7.5)) +noise_sd <- 1 +y <- f_X + rnorm(n, 0, noise_sd) + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +f_test <- f_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_mcmc <- 100 +num_trees <- 200 +n_reps <- 3 # repeated runs for stable timing + +cat(sprintf( + "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + n_train, + n_test, + p, + num_trees, + num_gfr, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + RMSE +# --------------------------------------------------------------------------- +run_once <- function(run_cpp, seed) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = num_mcmc, + mean_forest_params = list(num_trees = num_trees), + general_params = list(random_seed = seed), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - 
t0)[["elapsed"]] + + yhat <- rowMeans(m$y_hat_test) + rmse <- sqrt(mean((yhat - y_test)^2)) + rmse_f <- sqrt(mean((yhat - f_test)^2)) + + list(elapsed = elapsed, rmse = rmse, rmse_f = rmse_f) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + rmse <- sapply(results, `[[`, "rmse") + rmse_f <- sapply(results, `[[`, "rmse_f") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = sd(elapsed), + rmse_mean = mean(rmse), + rmse_sd = sd(rmse), + rmse_f_mean = mean(rmse_f), + rmse_f_sd = sd(rmse_f), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %12s\n", + "Sampler", + "Time (s)", + " SD", + "RMSE (obs)", + "RMSE (f)" +)) +cat(strrep("-", 72), "\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12.4f %12.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + res$rmse_mean[i], + res$rmse_f_mean[i] + )) +} + +speedup <- res$elapsed_mean[2] / res$elapsed_mean[1] +cat(sprintf("\nSpeedup 
(R / C++): %.2fx\n", speedup)) +cat(sprintf( + "RMSE delta (cpp - R): obs=%.4f f=%.4f\n", + res$rmse_mean[1] - res$rmse_mean[2], + res$rmse_f_mean[1] - res$rmse_f_mean[2] +)) diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index fb1b96bd..43112580 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -23,10 +23,10 @@ class BARTSampler { BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data); // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions - void run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr); + void run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_gfr, bool keep_gfr); // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions - void run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc); + void run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_burnin, int keep_every, int num_mcmc); private: /*! Initialize state variables */ @@ -34,7 +34,7 @@ class BARTSampler { bool initialized_ = false; /*! Internal sample runner function */ - void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, std::mt19937& rng, bool gfr, bool keep_sample); + void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample); /*! 
Mean forest state */ std::unique_ptr mean_forest_; diff --git a/man/bart.Rd b/man/bart.Rd index a1794208..cd7152e1 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -23,7 +23,8 @@ bart( general_params = list(), mean_forest_params = list(), variance_forest_params = list(), - random_effects_params = list() + random_effects_params = list(), + run_cpp = TRUE ) } \arguments{ @@ -138,6 +139,8 @@ referred to internally in the C++ layer as "variance weights" (\code{var_weights \item \code{variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. }} + +\item{run_cpp}{Whether or not to run the core C++ sampler. This is exposed as an argument for testing purposes, but in general should be left as \code{TRUE}. If \code{FALSE}, the function will run the previous version of the BART sampler in which the core loop logic was implemented in R, with C++ calls for most computationally intensive steps.} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
diff --git a/src/Makevars.in b/src/Makevars.in index de1457c9..cd539774 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -23,7 +23,9 @@ PKG_LIBS = \ OBJECTS = \ forest.o \ kernel.o \ + R_bart.o \ R_data.o \ + bart_sampler.o \ R_random_effects.o \ R_utils.o \ sampler.o \ diff --git a/src/R_bart.cpp b/src/R_bart.cpp new file mode 100644 index 00000000..47f428ce --- /dev/null +++ b/src/R_bart.cpp @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void check_numeric(cpp11::sexp input, const char* input_name) { + if (TYPEOF(input) != REALSXP && !Rf_isInteger(input)) { + cpp11::stop("Parameter %s must be a numeric array (integer or floating point)", input_name); + } +} + +double* extract_numeric_pointer(cpp11::sexp input, const char* input_name, int& protect_count) { + if (input == R_NilValue) return nullptr; + check_numeric(input, input_name); + cpp11::sexp input_converted = PROTECT(Rf_coerceVector(input, REALSXP)); + protect_count++; + return REAL(input_converted); +} + +void check_integer(cpp11::sexp input, const char* input_name) { + if (!Rf_isInteger(input)) { + cpp11::stop("Parameter %s must be an integer array", input_name); + } +} + +int* extract_integer_pointer(cpp11::sexp input, const char* input_name, int& protect_count) { + if (input == R_NilValue) return nullptr; + check_integer(input, input_name); + protect_count++; + return INTEGER(input); +} + +template +T get_config_scalar_default(cpp11::list& config_list, const char* config_key, T default_value) { + cpp11::sexp val = config_list[config_key]; + if (Rf_isNull(val)) return default_value; + return cpp11::as_cpp(val); +} + +template <> +int get_config_scalar_default(cpp11::list& config_list, const char* config_key, int default_value) { + cpp11::sexp val = config_list[config_key]; + if (Rf_isNull(val)) return default_value; + return Rf_asInteger(val); +} + +StochTree::BARTConfig convert_list_to_config(cpp11::list config) { + 
StochTree::BARTConfig output; + + // Global model parameters + output.standardize_outcome = get_config_scalar_default(config, "standardize_outcome", true); + output.num_threads = get_config_scalar_default(config, "num_threads", 1); + output.cutpoint_grid_size = get_config_scalar_default(config, "cutpoint_grid_size", 100); + output.link_function = static_cast(get_config_scalar_default(config, "link_function", 0)); + output.outcome_type = static_cast(get_config_scalar_default(config, "outcome_type", 0)); + output.random_seed = get_config_scalar_default(config, "random_seed", 1); + + // Global error variance parameters + output.a_sigma2_global = get_config_scalar_default(config, "a_sigma2_global", 0.0); + output.b_sigma2_global = get_config_scalar_default(config, "b_sigma2_global", 0.0); + output.sigma2_global_init = get_config_scalar_default(config, "sigma2_global_init", 1.0); + output.sample_sigma2_global = get_config_scalar_default(config, "sample_sigma2_global", true); + + // Mean forest parameters + output.num_trees_mean = get_config_scalar_default(config, "num_trees_mean", 200); + output.alpha_mean = get_config_scalar_default(config, "alpha_mean", 0.95); + output.beta_mean = get_config_scalar_default(config, "beta_mean", 2.0); + output.min_samples_leaf_mean = get_config_scalar_default(config, "min_samples_leaf_mean", 5); + output.max_depth_mean = get_config_scalar_default(config, "max_depth_mean", -1); + output.leaf_constant_mean = get_config_scalar_default(config, "leaf_constant_mean", true); + output.leaf_dim_mean = get_config_scalar_default(config, "leaf_dim_mean", 1); + output.exponentiated_leaf_mean = get_config_scalar_default(config, "exponentiated_leaf_mean", true); + output.num_features_subsample_mean = get_config_scalar_default(config, "num_features_subsample_mean", 0); + output.a_sigma2_mean = get_config_scalar_default(config, "a_sigma2_mean", 3.0); + output.b_sigma2_mean = get_config_scalar_default(config, "b_sigma2_mean", -1.0); + 
output.sigma2_mean_init = get_config_scalar_default(config, "sigma2_mean_init", -1.0); + output.sample_sigma2_leaf_mean = get_config_scalar_default(config, "sample_sigma2_leaf_mean", false); + + // Variance forest parameters + output.num_trees_variance = get_config_scalar_default(config, "num_trees_variance", 0); + output.leaf_prior_calibration_param = get_config_scalar_default(config, "leaf_prior_calibration_param", 1.5); + output.shape_variance_forest = get_config_scalar_default(config, "shape_variance_forest", -1.0); + output.scale_variance_forest = get_config_scalar_default(config, "scale_variance_forest", -1.0); + output.alpha_variance = get_config_scalar_default(config, "alpha_variance", 0.5); + output.beta_variance = get_config_scalar_default(config, "beta_variance", 2.0); + output.min_samples_leaf_variance = get_config_scalar_default(config, "min_samples_leaf_variance", 5); + output.max_depth_variance = get_config_scalar_default(config, "max_depth_variance", -1); + output.leaf_constant_variance = get_config_scalar_default(config, "leaf_constant_variance", true); + output.leaf_dim_variance = get_config_scalar_default(config, "leaf_dim_variance", 1); + output.exponentiated_leaf_variance = get_config_scalar_default(config, "exponentiated_leaf_variance", true); + output.num_features_subsample_variance = get_config_scalar_default(config, "num_features_subsample_variance", 0); + + // Handle vector conversions separately + cpp11::sexp feature_type_sxp = config["feature_types"]; + if (!Rf_isNull(feature_type_sxp)) { + cpp11::integers feature_types_r_vec(feature_type_sxp); + for (auto i : feature_types_r_vec) { + output.feature_types.push_back(static_cast(i)); + } + } + cpp11::sexp sweep_update_indices_sxp = config["sweep_update_indices"]; + if (!Rf_isNull(sweep_update_indices_sxp)) { + cpp11::integers sweep_update_indices_r_vec(sweep_update_indices_sxp); + output.sweep_update_indices.assign(sweep_update_indices_r_vec.begin(), sweep_update_indices_r_vec.end()); + } 
+ cpp11::sexp var_weights_mean_sxp = config["var_weights_mean"]; + if (!Rf_isNull(var_weights_mean_sxp)) { + cpp11::doubles var_weights_mean_r_vec(var_weights_mean_sxp); + output.var_weights_mean.assign(var_weights_mean_r_vec.begin(), var_weights_mean_r_vec.end()); + } + cpp11::sexp var_weights_variance_sxp = config["var_weights_variance"]; + if (!Rf_isNull(var_weights_variance_sxp)) { + cpp11::doubles var_weights_variance_r_vec(var_weights_variance_sxp); + output.var_weights_variance.assign(var_weights_variance_r_vec.begin(), var_weights_variance_r_vec.end()); + } +} + +cpp11::writable::list convert_bart_results_to_list(StochTree::BARTSamples& bart_samples) { + cpp11::writable::list output; + + // Pointers to forests + if (bart_samples.mean_forests.get() != nullptr) { + output["mean_forests"] = cpp11::external_pointer(bart_samples.mean_forests.release()); + } else { + output["mean_forests"] = R_NilValue; + } + + if (bart_samples.variance_forests.get() != nullptr) { + output["variance_forests"] = cpp11::external_pointer(bart_samples.variance_forests.release()); + } else { + output["variance_forests"] = R_NilValue; + } + + // Predictions + if (!bart_samples.mean_forest_predictions_train.empty()) { + output["mean_forest_predictions_train"] = cpp11::writable::doubles(bart_samples.mean_forest_predictions_train); + } else { + output["mean_forest_predictions_train"] = R_NilValue; + } + if (!bart_samples.variance_forest_predictions_train.empty()) { + output["variance_forest_predictions_train"] = cpp11::writable::doubles(bart_samples.variance_forest_predictions_train); + } else { + output["variance_forest_predictions_train"] = R_NilValue; + } + if (!bart_samples.mean_forest_predictions_test.empty()) { + output["mean_forest_predictions_test"] = cpp11::writable::doubles(bart_samples.mean_forest_predictions_test); + } else { + output["mean_forest_predictions_test"] = R_NilValue; + } + if (!bart_samples.variance_forest_predictions_test.empty()) { + 
output["variance_forest_predictions_test"] = cpp11::writable::doubles(bart_samples.variance_forest_predictions_test); + } else { + output["variance_forest_predictions_test"] = R_NilValue; + } + + // Parameter samples + if (!bart_samples.global_error_variance_samples.empty()) { + output["global_error_variance_samples"] = cpp11::writable::doubles(bart_samples.global_error_variance_samples); + } else { + output["global_error_variance_samples"] = R_NilValue; + } + if (!bart_samples.leaf_scale_samples.empty()) { + output["leaf_scale_samples"] = cpp11::writable::doubles(bart_samples.leaf_scale_samples); + } else { + output["leaf_scale_samples"] = R_NilValue; + } + + return output; +} + +[[cpp11::register]] +cpp11::writable::list bart_sample_cpp( + cpp11::sexp X_train, + cpp11::sexp y_train, + cpp11::sexp X_test, + int n_train, + int n_test, + int p, + cpp11::sexp basis_train, + cpp11::sexp basis_test, + int basis_dim, + cpp11::sexp obs_weights_train, + cpp11::sexp obs_weights_test, + cpp11::sexp rfx_group_ids_train, + cpp11::sexp rfx_group_ids_test, + cpp11::sexp rfx_basis_train, + cpp11::sexp rfx_basis_test, + int rfx_num_groups, + int rfx_basis_dim, + int num_gfr, + int num_burnin, + int keep_every, + int num_mcmc, + cpp11::list config_input) { + // Create smart pointer to outcome object + StochTree::BARTSamples results_raw = StochTree::BARTSamples(); + + // Extract pointers to raw data + int protect_count = 0; + double* X_train_ptr = extract_numeric_pointer(X_train, "X_train", protect_count); + double* y_train_ptr = extract_numeric_pointer(y_train, "y_train", protect_count); + double* X_test_ptr = extract_numeric_pointer(X_test, "X_test", protect_count); + double* basis_train_ptr = extract_numeric_pointer(basis_train, "basis_train", protect_count); + double* basis_test_ptr = extract_numeric_pointer(basis_test, "basis_test", protect_count); + double* obs_weights_train_ptr = extract_numeric_pointer(obs_weights_train, "obs_weights_train", protect_count); + double* 
obs_weights_test_ptr = extract_numeric_pointer(obs_weights_test, "obs_weights_test", protect_count); + int* rfx_group_ids_train_ptr = extract_integer_pointer(rfx_group_ids_train, "rfx_group_ids_train", protect_count); + int* rfx_group_ids_test_ptr = extract_integer_pointer(rfx_group_ids_test, "rfx_group_ids_test", protect_count); + double* rfx_basis_train_ptr = extract_numeric_pointer(rfx_basis_train, "rfx_basis_train", protect_count); + double* rfx_basis_test_ptr = extract_numeric_pointer(rfx_basis_test, "rfx_basis_test", protect_count); + + // Load the BARTData struct + // Consider reading directly from the R objects or at least checking for matches with the R object dimensions) + StochTree::BARTData data; + data.X_train = X_train_ptr; + data.y_train = y_train_ptr; + data.X_test = X_test_ptr; + data.n_train = n_train; + data.p = p; + data.n_test = n_test; + data.basis_train = basis_train_ptr; + data.basis_test = basis_test_ptr; + data.basis_dim = basis_dim; + data.obs_weights_train = obs_weights_train_ptr; + data.obs_weights_test = obs_weights_test_ptr; + data.rfx_group_ids_train = rfx_group_ids_train_ptr; + data.rfx_group_ids_test = rfx_group_ids_test_ptr; + data.rfx_basis_train = rfx_basis_train_ptr; + data.rfx_basis_test = rfx_basis_test_ptr; + data.rfx_num_groups = rfx_num_groups; + data.rfx_basis_dim = rfx_basis_dim; + + // Create the BARTConfig object + StochTree::BARTConfig config = convert_list_to_config(config_input); + + // Initialize a BART sampler + StochTree::BARTSampler bart_sampler(results_raw, config, data); + + // Run the sampler + bart_sampler.run_gfr(results_raw, config, data, num_gfr, true); + bart_sampler.run_mcmc(results_raw, config, data, num_burnin, keep_every, num_mcmc); + + // Unprotect protected R objects + UNPROTECT(protect_count); + + // Release management of the pointer to R session + return convert_bart_results_to_list(results_raw); +} diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 75e90e81..ac420596 100644 --- 
a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -2,14 +2,13 @@ #include #include #include +#include #include #include #include #include #include -#include #include -#include "stochtree/leaf_model.h" namespace StochTree { @@ -142,16 +141,16 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART initialized_ = true; } -void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_gfr, bool keep_gfr) { +void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_gfr, bool keep_gfr) { // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); for (int i = 0; i < num_gfr; i++) { - RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), rng, /*gfr=*/true, /*keep_sample=*/keep_gfr); + RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); } } -void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, std::mt19937& rng, int num_burnin, int keep_every, int num_mcmc) { +void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_burnin, int keep_every, int num_mcmc) { std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); bool keep_forest = false; @@ -160,16 +159,16 @@ void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& d keep_forest = true; else keep_forest = false; - RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), 
variance_leaf_model_ptr.get(), rng, /*gfr=*/false, /*keep_sample=*/keep_forest); + RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/false, /*keep_sample=*/keep_forest); } } -void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, std::mt19937& rng, bool gfr, bool keep_sample) { +void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample) { if (has_mean_forest_) { if (gfr) { GFRSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng, + *forest_dataset_, *residual_, *tree_prior_mean_, rng_, config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, config.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, @@ -177,7 +176,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART } else { MCMCSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng, + *forest_dataset_, *residual_, *tree_prior_mean_, rng_, config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, /*num_threads=*/config.num_threads); @@ -188,7 +187,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART if (gfr) { GFRSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, - *forest_dataset_, *residual_, *tree_prior_variance_, rng, + *forest_dataset_, *residual_, *tree_prior_variance_, rng_, 
config.var_weights_variance, config.sweep_update_indices, global_variance_, config.feature_types, config.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, @@ -196,7 +195,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART } else { MCMCSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, - *forest_dataset_, *residual_, *tree_prior_variance_, rng, + *forest_dataset_, *residual_, *tree_prior_variance_, rng_, config.var_weights_variance, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, /*num_threads=*/config.num_threads); diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 3626eb0b..fdea3b9e 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -5,6 +5,13 @@ #include "cpp11/declarations.hpp" #include +// R_bart.cpp +cpp11::writable::list bart_sample_cpp(cpp11::sexp X_train, cpp11::sexp y_train, cpp11::sexp X_test, int n_train, int n_test, int p, cpp11::sexp basis_train, cpp11::sexp basis_test, int basis_dim, cpp11::sexp obs_weights_train, cpp11::sexp obs_weights_test, cpp11::sexp rfx_group_ids_train, cpp11::sexp rfx_group_ids_test, cpp11::sexp rfx_basis_train, cpp11::sexp rfx_basis_test, int rfx_num_groups, int rfx_basis_dim, int num_gfr, int num_burnin, int keep_every, int num_mcmc, cpp11::list config_input); +extern "C" SEXP _stochtree_bart_sample_cpp(SEXP X_train, SEXP y_train, SEXP X_test, SEXP n_train, SEXP n_test, SEXP p, SEXP basis_train, SEXP basis_test, SEXP basis_dim, SEXP obs_weights_train, SEXP obs_weights_test, SEXP rfx_group_ids_train, SEXP rfx_group_ids_test, SEXP rfx_basis_train, SEXP rfx_basis_test, SEXP rfx_num_groups, SEXP rfx_basis_dim, SEXP num_gfr, SEXP num_burnin, SEXP keep_every, SEXP num_mcmc, SEXP config_input) { + BEGIN_CPP11 + return cpp11::as_sexp(bart_sample_cpp(cpp11::as_cpp>(X_train), cpp11::as_cpp>(y_train), cpp11::as_cpp>(X_test), 
cpp11::as_cpp>(n_train), cpp11::as_cpp>(n_test), cpp11::as_cpp>(p), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(basis_dim), cpp11::as_cpp>(obs_weights_train), cpp11::as_cpp>(obs_weights_test), cpp11::as_cpp>(rfx_group_ids_train), cpp11::as_cpp>(rfx_group_ids_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_num_groups), cpp11::as_cpp>(rfx_basis_dim), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(keep_every), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(config_input))); + END_CPP11 +} // R_data.cpp cpp11::external_pointer create_forest_dataset_cpp(); extern "C" SEXP _stochtree_create_forest_dataset_cpp() { @@ -1694,6 +1701,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, + {"_stochtree_bart_sample_cpp", (DL_FUNC) &_stochtree_bart_sample_cpp, 22}, {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 0e17038f..03b642fe 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,3 +1,5 @@ +#include +#include #include #include #include From 78d1c2590762bc338aaab3c07034d752dfe041ad Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 19:49:12 -0400 Subject: [PATCH 28/64] Mostly-working R implementation --- R/bart.R | 35 ++++++++- debug/benchmark_cpp_vs_r_sampler.R | 2 +- src/R_bart.cpp | 114 +++++++++++++++-------------- 3 
files changed, 94 insertions(+), 57 deletions(-) diff --git a/R/bart.R b/R/bart.R index e92a0c48..2cb11776 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1136,7 +1136,7 @@ bart <- function( "leaf_dim_variance" = 1, "exponentiated_leaf_variance" = TRUE, "num_features_subsample_variance" = num_features_subsample_variance, - "feature_types" = feature_types, + "feature_types" = as.integer(feature_types), "sweep_update_indices" = 0:(ncol(X_train) - 1), "var_weights_mean" = variable_weights_mean, "var_weights_variance" = variable_weights_variance @@ -1151,7 +1151,11 @@ bart <- function( p = ncol(X_train), basis_train = if (exists("leaf_basis_train")) leaf_basis_train else NULL, basis_test = if (exists("leaf_basis_test")) leaf_basis_test else NULL, - basis_dim = if (!is.null(leaf_basis_train)) ncol(leaf_basis_train) else 0L, + basis_dim = if (!is.null(leaf_basis_train)) { + ncol(leaf_basis_train) + } else { + 0L + }, obs_weights_train = if (exists("obs_weights_train")) { obs_weights_train } else { @@ -1178,7 +1182,11 @@ bart <- function( NULL }, rfx_basis_test = if (exists("rfx_basis_test")) rfx_basis_test else NULL, - rfx_num_groups = if (exists("num_rfx_groups")) as.integer(num_rfx_groups) else 0L, + rfx_num_groups = if (exists("num_rfx_groups")) { + as.integer(num_rfx_groups) + } else { + 0L + }, rfx_basis_dim = as.integer(num_basis_rfx), num_gfr = as.integer(num_gfr), num_burnin = as.integer(num_burnin), @@ -1187,6 +1195,27 @@ bart <- function( config_input = bart_config ) result <- bart_results + # TODO: store num_samples in the result list + if (!is.null(result['mean_forest_predictions_train'])) { + dim(result[['mean_forest_predictions_train']]) <- c( + result[["num_train"]], + result[["num_samples"]] + ) + y_hat_train_raw <- result[["mean_forest_predictions_train"]] + result[["y_hat_train"]] <- y_hat_train_raw * + result[["y_std"]] + + result[["y_bar"]] + } + if (!is.null(result['mean_forest_predictions_test'])) { + dim(result[['mean_forest_predictions_test']]) <- c( + 
result[["num_test"]], + result[["num_samples"]] + ) + y_hat_test_raw <- result[["mean_forest_predictions_test"]] + result[["y_hat_test"]] <- y_hat_test_raw * + result[["y_std"]] + + result[["y_bar"]] + } class(result) <- "bartmodel" } else { # Set a function-scoped RNG if user provided a random seed diff --git a/debug/benchmark_cpp_vs_r_sampler.R b/debug/benchmark_cpp_vs_r_sampler.R index 848ae97a..09e637db 100644 --- a/debug/benchmark_cpp_vs_r_sampler.R +++ b/debug/benchmark_cpp_vs_r_sampler.R @@ -55,7 +55,7 @@ cat(sprintf( # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + RMSE # --------------------------------------------------------------------------- -run_once <- function(run_cpp, seed) { +run_once <- function(run_cpp, seed = -1) { t0 <- proc.time() m <- bart( X_train = X_train, diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 47f428ce..e27189cc 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -96,79 +96,88 @@ StochTree::BARTConfig convert_list_to_config(cpp11::list config) { output.num_features_subsample_variance = get_config_scalar_default(config, "num_features_subsample_variance", 0); // Handle vector conversions separately - cpp11::sexp feature_type_sxp = config["feature_types"]; - if (!Rf_isNull(feature_type_sxp)) { - cpp11::integers feature_types_r_vec(feature_type_sxp); + SEXP feature_type_raw = static_cast(config["feature_types"]); + if (!Rf_isNull(feature_type_raw)) { + cpp11::integers feature_types_r_vec(feature_type_raw); for (auto i : feature_types_r_vec) { output.feature_types.push_back(static_cast(i)); } } - cpp11::sexp sweep_update_indices_sxp = config["sweep_update_indices"]; - if (!Rf_isNull(sweep_update_indices_sxp)) { - cpp11::integers sweep_update_indices_r_vec(sweep_update_indices_sxp); + SEXP sweep_update_indices_raw = static_cast(config["sweep_update_indices"]); + if (!Rf_isNull(sweep_update_indices_raw)) { + cpp11::integers 
sweep_update_indices_r_vec(sweep_update_indices_raw); output.sweep_update_indices.assign(sweep_update_indices_r_vec.begin(), sweep_update_indices_r_vec.end()); } - cpp11::sexp var_weights_mean_sxp = config["var_weights_mean"]; - if (!Rf_isNull(var_weights_mean_sxp)) { - cpp11::doubles var_weights_mean_r_vec(var_weights_mean_sxp); + SEXP var_weights_mean_raw = static_cast(config["var_weights_mean"]); + if (!Rf_isNull(var_weights_mean_raw)) { + cpp11::doubles var_weights_mean_r_vec(var_weights_mean_raw); output.var_weights_mean.assign(var_weights_mean_r_vec.begin(), var_weights_mean_r_vec.end()); } - cpp11::sexp var_weights_variance_sxp = config["var_weights_variance"]; - if (!Rf_isNull(var_weights_variance_sxp)) { - cpp11::doubles var_weights_variance_r_vec(var_weights_variance_sxp); + SEXP var_weights_variance_raw = static_cast(config["var_weights_variance"]); + if (!Rf_isNull(var_weights_variance_raw)) { + cpp11::doubles var_weights_variance_r_vec(var_weights_variance_raw); output.var_weights_variance.assign(var_weights_variance_r_vec.begin(), var_weights_variance_r_vec.end()); } + return output; } cpp11::writable::list convert_bart_results_to_list(StochTree::BARTSamples& bart_samples) { cpp11::writable::list output; // Pointers to forests - if (bart_samples.mean_forests.get() != nullptr) { - output["mean_forests"] = cpp11::external_pointer(bart_samples.mean_forests.release()); - } else { - output["mean_forests"] = R_NilValue; - } + SEXP mean_forests_sexp = (bart_samples.mean_forests.get() != nullptr) + ? static_cast(cpp11::external_pointer(bart_samples.mean_forests.release())) + : R_NilValue; + output.push_back(cpp11::named_arg("mean_forests") = mean_forests_sexp); - if (bart_samples.variance_forests.get() != nullptr) { - output["variance_forests"] = cpp11::external_pointer(bart_samples.variance_forests.release()); - } else { - output["variance_forests"] = R_NilValue; - } + SEXP variance_forests_sexp = (bart_samples.variance_forests.get() != nullptr) + ? 
static_cast(cpp11::external_pointer(bart_samples.variance_forests.release())) + : R_NilValue; + output.push_back(cpp11::named_arg("variance_forests") = variance_forests_sexp); // Predictions - if (!bart_samples.mean_forest_predictions_train.empty()) { - output["mean_forest_predictions_train"] = cpp11::writable::doubles(bart_samples.mean_forest_predictions_train); - } else { - output["mean_forest_predictions_train"] = R_NilValue; - } - if (!bart_samples.variance_forest_predictions_train.empty()) { - output["variance_forest_predictions_train"] = cpp11::writable::doubles(bart_samples.variance_forest_predictions_train); - } else { - output["variance_forest_predictions_train"] = R_NilValue; - } - if (!bart_samples.mean_forest_predictions_test.empty()) { - output["mean_forest_predictions_test"] = cpp11::writable::doubles(bart_samples.mean_forest_predictions_test); - } else { - output["mean_forest_predictions_test"] = R_NilValue; - } - if (!bart_samples.variance_forest_predictions_test.empty()) { - output["variance_forest_predictions_test"] = cpp11::writable::doubles(bart_samples.variance_forest_predictions_test); - } else { - output["variance_forest_predictions_test"] = R_NilValue; - } + SEXP mean_preds_train_sexp = !bart_samples.mean_forest_predictions_train.empty() + ? static_cast(cpp11::writable::doubles(bart_samples.mean_forest_predictions_train.begin(), bart_samples.mean_forest_predictions_train.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("mean_forest_predictions_train") = mean_preds_train_sexp); + + SEXP var_preds_train_sexp = !bart_samples.variance_forest_predictions_train.empty() + ? static_cast(cpp11::writable::doubles(bart_samples.variance_forest_predictions_train.begin(), bart_samples.variance_forest_predictions_train.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("variance_forest_predictions_train") = var_preds_train_sexp); + + SEXP mean_preds_test_sexp = !bart_samples.mean_forest_predictions_test.empty() + ? 
static_cast(cpp11::writable::doubles(bart_samples.mean_forest_predictions_test.begin(), bart_samples.mean_forest_predictions_test.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("mean_forest_predictions_test") = mean_preds_test_sexp); + + SEXP var_preds_test_sexp = !bart_samples.variance_forest_predictions_test.empty() + ? static_cast(cpp11::writable::doubles(bart_samples.variance_forest_predictions_test.begin(), bart_samples.variance_forest_predictions_test.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("variance_forest_predictions_test") = var_preds_test_sexp); // Parameter samples - if (!bart_samples.global_error_variance_samples.empty()) { - output["global_error_variance_samples"] = cpp11::writable::doubles(bart_samples.global_error_variance_samples); - } else { - output["global_error_variance_samples"] = R_NilValue; - } - if (!bart_samples.leaf_scale_samples.empty()) { - output["leaf_scale_samples"] = cpp11::writable::doubles(bart_samples.leaf_scale_samples); - } else { - output["leaf_scale_samples"] = R_NilValue; - } + SEXP global_var_sexp = !bart_samples.global_error_variance_samples.empty() + ? static_cast(cpp11::writable::doubles(bart_samples.global_error_variance_samples.begin(), bart_samples.global_error_variance_samples.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("global_error_variance_samples") = global_var_sexp); + + SEXP leaf_scale_sexp = !bart_samples.leaf_scale_samples.empty() + ? 
static_cast(cpp11::writable::doubles(bart_samples.leaf_scale_samples.begin(), bart_samples.leaf_scale_samples.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("leaf_scale_samples") = leaf_scale_sexp); + + // Sample metadata + double y_bar_sexp = bart_samples.y_bar; + output.push_back(cpp11::named_arg("y_bar") = y_bar_sexp); + double y_std_sexp = bart_samples.y_std; + output.push_back(cpp11::named_arg("y_std") = y_std_sexp); + int num_samples_sexp = bart_samples.num_samples; + output.push_back(cpp11::named_arg("num_samples") = num_samples_sexp); + int num_train_sexp = bart_samples.num_train; + output.push_back(cpp11::named_arg("num_train") = num_train_sexp); + int num_test_sexp = bart_samples.num_test; + output.push_back(cpp11::named_arg("num_test") = num_test_sexp); return output; } @@ -248,6 +257,5 @@ cpp11::writable::list bart_sample_cpp( // Unprotect protected R objects UNPROTECT(protect_count); - // Release management of the pointer to R session return convert_bart_results_to_list(results_raw); } From ec0e4c94400cd7d56bda0e530e7bf304b0078c73 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 10 Apr 2026 20:24:22 -0400 Subject: [PATCH 29/64] Fixed sigma2_leaf bug --- src/bart_sampler.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index ac420596..8fedab06 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -85,7 +85,7 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART config.sigma2_mean_init = y_var / config.num_trees_mean; } } - if (sample_sigma2_leaf_) { + if (config.sample_sigma2_leaf_mean) { if (config.b_sigma2_mean <= 0.0) { if (config.link_function == LinkFunction::Probit) { config.b_sigma2_mean = 1.0 / (2 * config.num_trees_mean); From 92b4485c2787dd6c273ed2ec8835350b0c8ceeec Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 07:50:31 -0400 Subject: [PATCH 30/64] Fixed sweep_update_indices bug --- R/bart.R | 11 +++++- 
debug/bart_debug.cpp | 12 +++---- debug/benchmark_cpp_vs_r_sampler.R | 18 +++++++--- include/stochtree/bart.h | 57 +++++++++++++++--------------- src/R_bart.cpp | 13 ++++--- src/bart_sampler.cpp | 8 ++--- 6 files changed, 72 insertions(+), 47 deletions(-) diff --git a/R/bart.R b/R/bart.R index 2cb11776..ebf5a99c 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1137,7 +1137,16 @@ bart <- function( "exponentiated_leaf_variance" = TRUE, "num_features_subsample_variance" = num_features_subsample_variance, "feature_types" = as.integer(feature_types), - "sweep_update_indices" = 0:(ncol(X_train) - 1), + "sweep_update_indices_mean" = if (num_trees_mean > 0) { + 0:(num_trees_mean - 1) + } else { + NULL + }, + "sweep_update_indices_variance" = if (num_trees_variance > 0) { + 0:(num_trees_variance - 1) + } else { + NULL + }, "var_weights_mean" = variable_weights_mean, "var_weights_variance" = variable_weights_variance ) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index b9d25958..b8122cfa 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -126,13 +126,13 @@ static void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, config.sample_sigma2_global = true; config.var_weights_mean = std::vector(p, 1.0 / p); config.feature_types = std::vector(p, StochTree::FeatureType::kNumeric); - config.sweep_update_indices = std::vector(num_trees, 0); - std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); + config.sweep_update_indices_mean = std::vector(num_trees, 0); + std::iota(config.sweep_update_indices_mean.begin(), config.sweep_update_indices_mean.end(), 0); StochTree::BARTSamples samples; StochTree::BARTSampler sampler(samples, config, data); - sampler.run_gfr(samples, config, data, rng, num_gfr, true); - sampler.run_mcmc(samples, config, data, rng, 0, 1, num_mcmc); + sampler.run_gfr(samples, config, data, num_gfr, true); + sampler.run_mcmc(samples, config, data, 0, 1, num_mcmc); report_bart(samples, test.y, "Scenario 0 
(Homoskedastic BART)"); } @@ -158,8 +158,8 @@ static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, config.sample_sigma2_global = false; config.var_weights_mean = std::vector(p, 1.0 / p); config.feature_types = std::vector(p, StochTree::FeatureType::kNumeric); - config.sweep_update_indices = std::vector(num_trees, 0); - std::iota(config.sweep_update_indices.begin(), config.sweep_update_indices.end(), 0); + config.sweep_update_indices_mean = std::vector(num_trees, 0); + std::iota(config.sweep_update_indices_mean.begin(), config.sweep_update_indices_mean.end(), 0); StochTree::BARTSamples samples; StochTree::BARTSampler sampler(samples, config, data); diff --git a/debug/benchmark_cpp_vs_r_sampler.R b/debug/benchmark_cpp_vs_r_sampler.R index 09e637db..4b6c05c3 100644 --- a/debug/benchmark_cpp_vs_r_sampler.R +++ b/debug/benchmark_cpp_vs_r_sampler.R @@ -10,7 +10,7 @@ library(stochtree) # --------------------------------------------------------------------------- set.seed(1234) -n <- 2000 +n <- 10000 p <- 10 X <- matrix(runif(n * p), ncol = p) f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * @@ -55,7 +55,7 @@ cat(sprintf( # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + RMSE # --------------------------------------------------------------------------- -run_once <- function(run_cpp, seed = -1) { +run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { t0 <- proc.time() m <- bart( X_train = X_train, @@ -88,13 +88,23 @@ results_r <- vector("list", n_reps) cat("Running C++ sampler (run_cpp = TRUE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, n_reps)) - results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) + results_cpp[[i]] <- run_once( + run_cpp = TRUE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) } cat("\nRunning R sampler (run_cpp = FALSE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, 
n_reps)) - results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) + results_r[[i]] <- run_once( + run_cpp = FALSE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) } # --------------------------------------------------------------------------- diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 5548bef2..31562250 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -61,7 +61,6 @@ struct BARTConfig { int num_threads = 1; // number of threads to use for sampling int cutpoint_grid_size = 100; // number of cutpoints to consider for each covariate when sampling splits std::vector feature_types; // feature types for each covariate (should be same length as number of covariates in the dataset), where 0 = continuous, 1 = categorical - std::vector sweep_update_indices; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) LinkFunction link_function = LinkFunction::Identity; // link function to use (Identity, Probit, Cloglog) OutcomeType outcome_type = OutcomeType::Continuous; // type of the outcome variable (Continuous, Binary, Ordinal) int random_seed = -1; // random seed for reproducibility (if negative, a random seed will be generated) @@ -73,35 +72,37 @@ struct BARTConfig { bool sample_sigma2_global = true; // whether to sample global error variance (if false, it will be fixed at sigma2_global_init) // Mean forest parameters - int num_trees_mean = 200; // number of trees in the mean forest - double alpha_mean = 0.95; // alpha parameter for mean forest tree prior - double beta_mean = 2.0; // beta parameter for mean forest tree prior - int min_samples_leaf_mean = 5; // minimum number of samples per leaf for mean forest - int max_depth_mean = -1; // maximum depth for mean forest trees (-1 means no maximum) - bool leaf_constant_mean = true; // whether to use constant leaf model for mean forest - int leaf_dim_mean = 1; // dimension of the leaf for mean forest - bool 
exponentiated_leaf_mean = false; // whether to exponentiate leaf predictions for mean forest - int num_features_subsample_mean = 0; // number of features to subsample for each mean forest split (0 means no subsampling) - double a_sigma2_mean = 3.0; // shape parameter for inverse gamma prior on mean forest leaf scale - double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) - double sigma2_mean_init = -1.0; // initial value of mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) - std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) - bool sample_sigma2_leaf_mean = false; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) + int num_trees_mean = 200; // number of trees in the mean forest + double alpha_mean = 0.95; // alpha parameter for mean forest tree prior + double beta_mean = 2.0; // beta parameter for mean forest tree prior + int min_samples_leaf_mean = 5; // minimum number of samples per leaf for mean forest + int max_depth_mean = -1; // maximum depth for mean forest trees (-1 means no maximum) + bool leaf_constant_mean = true; // whether to use constant leaf model for mean forest + int leaf_dim_mean = 1; // dimension of the leaf for mean forest + bool exponentiated_leaf_mean = false; // whether to exponentiate leaf predictions for mean forest + int num_features_subsample_mean = 0; // number of features to subsample for each mean forest split (0 means no subsampling) + double a_sigma2_mean = 3.0; // shape parameter for inverse gamma prior on mean forest leaf scale + double b_sigma2_mean = -1.0; // scale parameter for inverse gamma prior on mean forest leaf scale (-1 is a 
sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) + double sigma2_mean_init = -1.0; // initial value of mean forest leaf scale (-1 is a sentinel value that triggers a data-informed calibration based on the variance of the outcome and the number of trees) + std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) + bool sample_sigma2_leaf_mean = false; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) + std::vector sweep_update_indices_mean; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) // Variance forest parameters - int num_trees_variance = 0; // number of trees in the variance forest - double leaf_prior_calibration_param = 1.5; // calibration parameter for variance forest leaf prior - double shape_variance_forest = -1.0; // shape parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) - double scale_variance_forest = -1.0; // scale parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) - double alpha_variance = 0.5; // alpha parameter for variance forest tree prior - double beta_variance = 2.0; // beta parameter for variance forest tree prior - int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest - int max_depth_variance = -1; // maximum depth for variance forest trees (-1 means no maximum) - bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest - int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) - bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest - int 
num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) - std::vector var_weights_variance; // variable weights for variance forest splits (should be same length as number of covariates in the dataset) + int num_trees_variance = 0; // number of trees in the variance forest + double leaf_prior_calibration_param = 1.5; // calibration parameter for variance forest leaf prior + double shape_variance_forest = -1.0; // shape parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) + double scale_variance_forest = -1.0; // scale parameter for variance forest leaf model (calibrated internally based on leaf_prior_calibration_param if set to sentinel value of -1) + double alpha_variance = 0.5; // alpha parameter for variance forest tree prior + double beta_variance = 2.0; // beta parameter for variance forest tree prior + int min_samples_leaf_variance = 5; // minimum number of samples per leaf for variance forest + int max_depth_variance = -1; // maximum depth for variance forest trees (-1 means no maximum) + bool leaf_constant_variance = true; // whether to use constant leaf model for variance forest + int leaf_dim_variance = 1; // dimension of the leaf for variance forest (should be 1 if leaf_constant_variance=true) + bool exponentiated_leaf_variance = true; // whether to exponentiate leaf predictions for variance forest + int num_features_subsample_variance = 0; // number of features to subsample for each variance forest split (0 means no subsampling) + std::vector var_weights_variance; // variable weights for variance forest splits (should be same length as number of covariates in the dataset) + std::vector sweep_update_indices_variance; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) // TODO: Random effects parameters ... 
diff --git a/src/R_bart.cpp b/src/R_bart.cpp index e27189cc..0398c3fc 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -103,10 +103,15 @@ StochTree::BARTConfig convert_list_to_config(cpp11::list config) { output.feature_types.push_back(static_cast(i)); } } - SEXP sweep_update_indices_raw = static_cast(config["sweep_update_indices"]); - if (!Rf_isNull(sweep_update_indices_raw)) { - cpp11::integers sweep_update_indices_r_vec(sweep_update_indices_raw); - output.sweep_update_indices.assign(sweep_update_indices_r_vec.begin(), sweep_update_indices_r_vec.end()); + SEXP sweep_update_indices_mean_raw = static_cast(config["sweep_update_indices_mean"]); + if (!Rf_isNull(sweep_update_indices_mean_raw)) { + cpp11::integers sweep_update_indices_mean_r_vec(sweep_update_indices_mean_raw); + output.sweep_update_indices_mean.assign(sweep_update_indices_mean_r_vec.begin(), sweep_update_indices_mean_r_vec.end()); + } + SEXP sweep_update_indices_variance_raw = static_cast(config["sweep_update_indices_variance"]); + if (!Rf_isNull(sweep_update_indices_variance_raw)) { + cpp11::integers sweep_update_indices_variance_r_vec(sweep_update_indices_variance_raw); + output.sweep_update_indices_variance.assign(sweep_update_indices_variance_r_vec.begin(), sweep_update_indices_variance_r_vec.end()); } SEXP var_weights_mean_raw = static_cast(config["var_weights_mean"]); if (!Rf_isNull(var_weights_mean_raw)) { diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 8fedab06..8aa03b60 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -169,7 +169,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART GFRSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config.var_weights_mean, config.sweep_update_indices, global_variance_, config.feature_types, + config.var_weights_mean, config.sweep_update_indices_mean, global_variance_, config.feature_types, 
config.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, /*num_features_subsample=*/data.p, config.num_threads); @@ -177,7 +177,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART MCMCSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config.var_weights_mean, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, + config.var_weights_mean, config.sweep_update_indices_mean, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, /*num_threads=*/config.num_threads); } @@ -188,7 +188,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART GFRSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, - config.var_weights_variance, config.sweep_update_indices, global_variance_, config.feature_types, + config.var_weights_variance, config.sweep_update_indices_variance, global_variance_, config.feature_types, config.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, /*num_features_subsample=*/data.p, config.num_threads); @@ -196,7 +196,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART MCMCSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, - config.var_weights_variance, config.sweep_update_indices, global_variance_, /*keep_forest=*/keep_sample, + config.var_weights_variance, config.sweep_update_indices_variance, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, /*num_threads=*/config.num_threads); } From 35020b14834b8cc06c7dc2f92b9a7185a9b1a899 Mon Sep 17 
00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 08:45:13 -0400 Subject: [PATCH 31/64] Updated R package to default to no cpp loop for now (unit tests crashing on incomplete implementation) --- R/bart.R | 4 ++-- man/bart.Rd | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/bart.R b/R/bart.R index ebf5a99c..c58493cd 100644 --- a/R/bart.R +++ b/R/bart.R @@ -157,7 +157,7 @@ NULL #' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' -#' @param run_cpp Whether or not to run the core C++ sampler. This is exposed as an argument for testing purposes, but in general should be left as `TRUE`. If `FALSE`, the function will run the previous version of the BART sampler in which the core loop logic was implemented in R, with C++ calls for most computationally intensive steps. +#' @param run_cpp Whether or not to run the core C++ sampler. Default `FALSE`, but will eventually be set to `TRUE`. #' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
#' @export @@ -206,7 +206,7 @@ bart <- function( mean_forest_params = list(), variance_forest_params = list(), random_effects_params = list(), - run_cpp = TRUE + run_cpp = FALSE ) { # Update general BART parameters general_params_default <- list( diff --git a/man/bart.Rd b/man/bart.Rd index cd7152e1..f6b1b15b 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -24,7 +24,7 @@ bart( mean_forest_params = list(), variance_forest_params = list(), random_effects_params = list(), - run_cpp = TRUE + run_cpp = FALSE ) } \arguments{ @@ -140,7 +140,7 @@ referred to internally in the C++ layer as "variance weights" (\code{var_weights \item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. }} -\item{run_cpp}{Whether or not to run the core C++ sampler. This is exposed as an argument for testing purposes, but in general should be left as \code{TRUE}. If \code{FALSE}, the function will run the previous version of the BART sampler in which the core loop logic was implemented in R, with C++ calls for most computationally intensive steps.} +\item{run_cpp}{Whether or not to run the core C++ sampler. Default \code{FALSE}, but will eventually be set to \code{TRUE}.} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
From 455ac9d622e83fc61df63043aad25234ff78a7cc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 19:10:26 -0400 Subject: [PATCH 32/64] Initial python wrapper around C++ interface --- debug/benchmark_cpp_vs_py_sampler.py | 126 ++ src/R_bart.cpp | 2 +- src/py_stochtree.cpp | 316 ++++ stochtree/bart.py | 2138 ++++++++++++++------------ 4 files changed, 1566 insertions(+), 1016 deletions(-) create mode 100644 debug/benchmark_cpp_vs_py_sampler.py diff --git a/debug/benchmark_cpp_vs_py_sampler.py b/debug/benchmark_cpp_vs_py_sampler.py new file mode 100644 index 00000000..a4e94eb8 --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler.py @@ -0,0 +1,126 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop. + +Compares runtime and test-set RMSE across run_cpp=True / False in BARTModel.sample(). + +Usage: + source venv/bin/activate + python debug/benchmark_cpp_vs_py_sampler.py +""" + +import time +import numpy as np +from stochtree import BARTModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 10000 +p = 10 +X = rng.uniform(size=(n, p)) +f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -7.5, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -2.5, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 2.5, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 7.5, 0) +) +noise_sd = 1.0 +y = f_X + rng.normal(scale=noise_sd, size=n) + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +f_test = f_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# 
--------------------------------------------------------------------------- +num_gfr = 10 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 + +print( + f"n_train={n_train} n_test={n_test} p={p} " + f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + RMSE +# --------------------------------------------------------------------------- +def run_once(run_cpp, num_gfr, num_mcmc, seed): + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=0, + num_mcmc=num_mcmc, + general_params={"random_seed": seed}, + mean_forest_params={"num_trees": num_trees}, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + yhat = m.y_hat_test.mean(axis=1) + rmse = np.sqrt(np.mean((yhat - y_test) ** 2)) + rmse_f = np.sqrt(np.mean((yhat - f_test) ** 2)) + return {"elapsed": elapsed, "rmse": rmse, "rmse_f": rmse_f} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results): + elapsed = [r["elapsed"] for r in results] + rmse = 
[r["rmse"] for r in results] + rmse_f = [r["rmse_f"] for r in results] + return { + "elapsed_mean": np.mean(elapsed), "elapsed_sd": np.std(elapsed, ddof=1), + "rmse_mean": np.mean(rmse), + "rmse_f_mean": np.mean(rmse_f), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print(f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} {'RMSE (obs)':>12} {'RMSE (f)':>12}") +print("-" * 72) +for label, s in rows: + print(f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['rmse_mean']:>12.4f} {s['rmse_f_mean']:>12.4f}") + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"RMSE delta (cpp - py): " + f"obs={s_cpp['rmse_mean'] - s_py['rmse_mean']:.4f} " + f"f={s_cpp['rmse_f_mean'] - s_py['rmse_f_mean']:.4f}" +) diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 0398c3fc..ff82059f 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -211,7 +211,7 @@ cpp11::writable::list bart_sample_cpp( int keep_every, int num_mcmc, cpp11::list config_input) { - // Create smart pointer to outcome object + // Create outcome object StochTree::BARTSamples results_raw = StochTree::BARTSamples(); // Extract pointers to raw data diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 020bfbaf..5c20e422 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1,7 +1,10 @@ #include #include +#include #include #include +#include +#include #include #include #include @@ -253,6 +256,13 @@ class ForestContainerCpp { is_leaf_constant_ = is_leaf_constant; is_exponentiated_ = is_exponentiated; } + ForestContainerCpp(std::unique_ptr forest_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { + forest_samples_ = std::move(forest_samples); + num_trees_ = num_trees; + output_dimension_ = output_dimension; + is_leaf_constant_ = is_leaf_constant; 
+ is_exponentiated_ = is_exponentiated; + } ~ForestContainerCpp() {} void CombineForests(py::array_t forest_inds) { @@ -2146,6 +2156,288 @@ class JsonCpp { std::unique_ptr json_; }; +template +T get_config_scalar_default(py::dict& config_dict, const char* config_key, T default_value) { + return config_dict.contains(config_key) ? config_dict[config_key].cast() : default_value; +} + +inline StochTree::BARTConfig convert_dict_to_bart_config(py::dict config_dict) { + StochTree::BARTConfig output; + output.num_trees_mean = get_config_scalar_default(config_dict, "num_trees_mean", 200); + output.alpha_mean = get_config_scalar_default(config_dict, "alpha_mean", 0.95); + output.beta_mean = get_config_scalar_default(config_dict, "beta_mean", 2.0); + + // Global model parameters + output.num_trees_mean = get_config_scalar_default(config_dict, "num_trees_mean", 200); + output.standardize_outcome = get_config_scalar_default(config_dict, "standardize_outcome", true); + output.num_threads = get_config_scalar_default(config_dict, "num_threads", 1); + output.cutpoint_grid_size = get_config_scalar_default(config_dict, "cutpoint_grid_size", 100); + output.link_function = static_cast(get_config_scalar_default(config_dict, "link_function", 0)); + output.outcome_type = static_cast(get_config_scalar_default(config_dict, "outcome_type", 0)); + output.random_seed = get_config_scalar_default(config_dict, "random_seed", 1); + + // Global error variance parameters + output.a_sigma2_global = get_config_scalar_default(config_dict, "a_sigma2_global", 0.0); + output.b_sigma2_global = get_config_scalar_default(config_dict, "b_sigma2_global", 0.0); + output.sigma2_global_init = get_config_scalar_default(config_dict, "sigma2_global_init", 1.0); + output.sample_sigma2_global = get_config_scalar_default(config_dict, "sample_sigma2_global", true); + + // Mean forest parameters + output.num_trees_mean = get_config_scalar_default(config_dict, "num_trees_mean", 200); + output.alpha_mean = 
get_config_scalar_default(config_dict, "alpha_mean", 0.95); + output.beta_mean = get_config_scalar_default(config_dict, "beta_mean", 2.0); + output.min_samples_leaf_mean = get_config_scalar_default(config_dict, "min_samples_leaf_mean", 5); + output.max_depth_mean = get_config_scalar_default(config_dict, "max_depth_mean", -1); + output.leaf_constant_mean = get_config_scalar_default(config_dict, "leaf_constant_mean", true); + output.leaf_dim_mean = get_config_scalar_default(config_dict, "leaf_dim_mean", 1); + output.exponentiated_leaf_mean = get_config_scalar_default(config_dict, "exponentiated_leaf_mean", true); + output.num_features_subsample_mean = get_config_scalar_default(config_dict, "num_features_subsample_mean", 0); + output.a_sigma2_mean = get_config_scalar_default(config_dict, "a_sigma2_mean", 3.0); + output.b_sigma2_mean = get_config_scalar_default(config_dict, "b_sigma2_mean", -1.0); + output.sigma2_mean_init = get_config_scalar_default(config_dict, "sigma2_mean_init", -1.0); + output.sample_sigma2_leaf_mean = get_config_scalar_default(config_dict, "sample_sigma2_leaf_mean", false); + + // Variance forest parameters + output.num_trees_variance = get_config_scalar_default(config_dict, "num_trees_variance", 0); + output.leaf_prior_calibration_param = get_config_scalar_default(config_dict, "leaf_prior_calibration_param", 1.5); + output.shape_variance_forest = get_config_scalar_default(config_dict, "shape_variance_forest", -1.0); + output.scale_variance_forest = get_config_scalar_default(config_dict, "scale_variance_forest", -1.0); + output.alpha_variance = get_config_scalar_default(config_dict, "alpha_variance", 0.5); + output.beta_variance = get_config_scalar_default(config_dict, "beta_variance", 2.0); + output.min_samples_leaf_variance = get_config_scalar_default(config_dict, "min_samples_leaf_variance", 5); + output.max_depth_variance = get_config_scalar_default(config_dict, "max_depth_variance", -1); + output.leaf_constant_variance = 
get_config_scalar_default(config_dict, "leaf_constant_variance", true); + output.leaf_dim_variance = get_config_scalar_default(config_dict, "leaf_dim_variance", 1); + output.exponentiated_leaf_variance = get_config_scalar_default(config_dict, "exponentiated_leaf_variance", true); + output.num_features_subsample_variance = get_config_scalar_default(config_dict, "num_features_subsample_variance", 0); + + // Handle vector conversions separately + if (config_dict.contains("feature_types")) { + std::vector feature_types_vector = config_dict["feature_types"].cast>(); + for (auto item : feature_types_vector) { + output.feature_types.push_back(static_cast(item)); + } + } + if (config_dict.contains("sweep_update_indices_mean")) { + output.sweep_update_indices_mean = config_dict["sweep_update_indices_mean"].cast>(); + } + if (config_dict.contains("sweep_update_indices_variance")) { + output.sweep_update_indices_variance = config_dict["sweep_update_indices_variance"].cast>(); + } + if (config_dict.contains("var_weights_mean")) { + output.var_weights_mean = config_dict["var_weights_mean"].cast>(); + } + if (config_dict.contains("var_weights_variance")) { + output.var_weights_variance = config_dict["var_weights_variance"].cast>(); + } + return output; +} + +inline StochTree::BARTData convert_numpy_to_bart_data( + py::object X_train, + py::object y_train, + py::object X_test, + int n_train, + int n_test, + int p, + py::object basis_train, + py::object basis_test, + int basis_dim, + py::object obs_weights_train, + py::object obs_weights_test, + py::object rfx_group_ids_train, + py::object rfx_group_ids_test, + py::object rfx_basis_train, + py::object rfx_basis_test, + int rfx_num_groups, + int rfx_basis_dim) { + StochTree::BARTData output; + if (!X_train.is_none()) { + py::array_t X_train_array = X_train.cast>(); + output.X_train = static_cast(X_train_array.mutable_data()); + } + if (!y_train.is_none()) { + py::array_t y_train_array = y_train.cast>(); + output.y_train = 
static_cast(y_train_array.mutable_data()); + } + if (!X_test.is_none()) { + py::array_t X_test_array = X_test.cast>(); + output.X_test = static_cast(X_test_array.mutable_data()); + } + if (!basis_train.is_none()) { + py::array_t basis_train_array = basis_train.cast>(); + output.basis_train = static_cast(basis_train_array.mutable_data()); + } + if (!basis_test.is_none()) { + py::array_t basis_test_array = basis_test.cast>(); + output.basis_test = static_cast(basis_test_array.mutable_data()); + } + if (!obs_weights_train.is_none()) { + py::array_t obs_weights_train_array = obs_weights_train.cast>(); + output.obs_weights_train = static_cast(obs_weights_train_array.mutable_data()); + } + if (!obs_weights_test.is_none()) { + py::array_t obs_weights_test_array = obs_weights_test.cast>(); + output.obs_weights_test = static_cast(obs_weights_test_array.mutable_data()); + } + if (!rfx_group_ids_train.is_none()) { + py::array_t rfx_group_ids_train_array = rfx_group_ids_train.cast>(); + output.rfx_group_ids_train = static_cast(rfx_group_ids_train_array.mutable_data()); + } + if (!rfx_group_ids_test.is_none()) { + py::array_t rfx_group_ids_test_array = rfx_group_ids_test.cast>(); + output.rfx_group_ids_test = static_cast(rfx_group_ids_test_array.mutable_data()); + } + if (!rfx_basis_train.is_none()) { + py::array_t rfx_basis_train_array = rfx_basis_train.cast>(); + output.rfx_basis_train = static_cast(rfx_basis_train_array.mutable_data()); + } + if (!rfx_basis_test.is_none()) { + py::array_t rfx_basis_test_array = rfx_basis_test.cast>(); + output.rfx_basis_test = static_cast(rfx_basis_test_array.mutable_data()); + } + output.n_train = n_train; + output.n_test = n_test; + output.p = p; + output.basis_dim = basis_dim; + output.rfx_num_groups = rfx_num_groups; + output.rfx_basis_dim = rfx_basis_dim; + return output; +} + +inline py::dict convert_bart_results_to_dict( + StochTree::BARTSamples& results_raw, StochTree::BARTConfig& config) { + py::dict output; + + // Transfer 
ownership of mean forest pointers + if (results_raw.mean_forests != nullptr) { + output["mean_forests"] = py::cast(ForestContainerCpp(std::move(results_raw.mean_forests), config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean)); + } else { + output["mean_forests"] = py::none(); + } + + // Transfer ownership of variance forest pointers + if (results_raw.variance_forests != nullptr) { + output["variance_forests"] = py::cast(ForestContainerCpp(std::move(results_raw.variance_forests), config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance)); + } else { + output["variance_forests"] = py::none(); + } + + // Move parameter vector samples + + // Train set mean forest predictions + if (results_raw.mean_forest_predictions_train.empty()) { + output["mean_forest_predictions_train"] = py::none(); + } else { + auto input_vec = results_raw.mean_forest_predictions_train; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["mean_forest_predictions_train"] = array; + } + + // Test set mean forest predictions + if (results_raw.mean_forest_predictions_test.empty()) { + output["mean_forest_predictions_test"] = py::none(); + } else { + auto input_vec = results_raw.mean_forest_predictions_test; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["mean_forest_predictions_test"] = array; + } + + // Train set variance forest predictions + if (results_raw.variance_forest_predictions_train.empty()) { + output["variance_forest_predictions_train"] = py::none(); + } else { + auto input_vec = results_raw.variance_forest_predictions_train; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["variance_forest_predictions_train"] = array; + } + + // Test set variance forest predictions + if 
(results_raw.variance_forest_predictions_test.empty()) { + output["variance_forest_predictions_test"] = py::none(); + } else { + auto input_vec = results_raw.variance_forest_predictions_test; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["variance_forest_predictions_test"] = array; + } + + // Global error variance samples + if (results_raw.global_error_variance_samples.empty()) { + output["global_error_variance_samples"] = py::none(); + } else { + auto input_vec = results_raw.global_error_variance_samples; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["global_error_variance_samples"] = array; + } + + // Leaf scale samples + if (results_raw.leaf_scale_samples.empty()) { + output["leaf_scale_samples"] = py::none(); + } else { + auto input_vec = results_raw.leaf_scale_samples; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["leaf_scale_samples"] = array; + } + + // Unpack scalars + output["y_bar"] = results_raw.y_bar; + output["y_std"] = results_raw.y_std; + output["num_samples"] = results_raw.num_samples; + output["num_train"] = results_raw.num_train; + output["num_test"] = results_raw.num_test; + + return output; +} + +py::dict bart_sample_cpp( + py::object X_train, + py::object y_train, + py::object X_test, + int n_train, + int n_test, + int p, + py::object basis_train, + py::object basis_test, + int basis_dim, + py::object obs_weights_train, + py::object obs_weights_test, + py::object rfx_group_ids_train, + py::object rfx_group_ids_test, + py::object rfx_basis_train, + py::object rfx_basis_test, + int rfx_num_groups, + int rfx_basis_dim, + int num_gfr, + int num_burnin, + int keep_every, + int num_mcmc, + py::dict config_input) { + // Convert config dict to BARTConfig struct + StochTree::BARTConfig bart_config = 
convert_dict_to_bart_config(config_input); + + // Unpack pointers to input data to BARTData object + StochTree::BARTData bart_data = convert_numpy_to_bart_data(X_train, y_train, X_test, n_train, n_test, p, basis_train, basis_test, basis_dim, obs_weights_train, obs_weights_test, rfx_group_ids_train, rfx_group_ids_test, rfx_basis_train, rfx_basis_test, rfx_num_groups, rfx_basis_dim); + + // Create outcome object + StochTree::BARTSamples bart_results_raw = StochTree::BARTSamples(); + + // Initialize a BART sampler + StochTree::BARTSampler bart_sampler(bart_results_raw, bart_config, bart_data); + + // Run the sampler + bart_sampler.run_gfr(bart_results_raw, bart_config, bart_data, num_gfr, true); + bart_sampler.run_mcmc(bart_results_raw, bart_config, bart_data, num_burnin, keep_every, num_mcmc); + + // Convert results to Python dictionary + return convert_bart_results_to_dict(bart_results_raw, bart_config); +} + py::array_t cppComputeForestContainerLeafIndices(ForestContainerCpp& forest_container, py::array_t& covariates, py::array_t& forest_nums) { // Wrap an Eigen Map around the raw data of the covariate matrix StochTree::data_size_t num_obs = covariates.shape(0); @@ -2267,6 +2559,29 @@ void RandomEffectsTrackerCpp::RootReset(RandomEffectsModelCpp& rfx_model, Random PYBIND11_MODULE(stochtree_cpp, m) { m.def("cppComputeForestContainerLeafIndices", &cppComputeForestContainerLeafIndices, "Compute leaf indices of the forests in a forest container"); m.def("cppComputeForestMaxLeafIndex", &cppComputeForestMaxLeafIndex, "Compute max leaf index of a forest in a forest container"); + m.def("bart_sample_cpp", &bart_sample_cpp, "Run BART sampler in C++ implementation", + py::arg("X_train"), + py::arg("y_train"), + py::arg("X_test") = py::none(), + py::arg("n_train"), + py::arg("n_test"), + py::arg("p"), + py::arg("basis_train") = py::none(), + py::arg("basis_test") = py::none(), + py::arg("basis_dim"), + py::arg("obs_weights_train") = py::none(), + py::arg("obs_weights_test") = 
py::none(), + py::arg("rfx_group_ids_train") = py::none(), + py::arg("rfx_group_ids_test") = py::none(), + py::arg("rfx_basis_train") = py::none(), + py::arg("rfx_basis_test") = py::none(), + py::arg("rfx_num_groups"), + py::arg("rfx_basis_dim"), + py::arg("num_gfr"), + py::arg("num_burnin"), + py::arg("keep_every"), + py::arg("num_mcmc"), + py::arg("config_input")); py::class_(m, "JsonCpp") .def(py::init<>()) @@ -2537,6 +2852,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("UpdateLatentVariables", &OrdinalSamplerCpp::UpdateLatentVariables) .def("UpdateGammaParams", &OrdinalSamplerCpp::UpdateGammaParams) .def("UpdateCumulativeExpSums", &OrdinalSamplerCpp::UpdateCumulativeExpSums); + ; #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); diff --git a/stochtree/bart.py b/stochtree/bart.py index d9451c8c..cd8be874 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -33,6 +33,7 @@ _posterior_predictive_heuristic_multiplier, _summarize_interval, ) +from stochtree_cpp import bart_sample_cpp class BARTModel: @@ -95,6 +96,7 @@ def sample( mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, random_effects_params: Optional[Dict[str, Any]] = None, + run_cpp: bool = False, ) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -169,6 +171,9 @@ def sample( counting backwards as noted before. If more chains are requested than there are samples in `previous_model_json`, a warning will be raised and only the last sample will be used. + run_cpp : bool, optional + Whether to run the C++ implementation of the BART sampler. Defaults to `False`. + Returns ------- @@ -1049,204 +1054,7 @@ def sample( "Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model." 
) sample_sigma2_global = False - - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if link_is_probit: - # Compute a probit-scale offset and fix scale to 1 - self.y_bar = norm.ppf(np.squeeze(np.mean(y_train))) - self.y_std = 1.0 - - # Set a pseudo outcome by subtracting mean(y_train) from y_train - resid_train = y_train - np.squeeze(np.mean(y_train)) - - # Set initial values of root nodes to 0.0 (in probit scale) - init_val_mean = 0.0 - - # Calibrate priors for sigma^2 and tau - # Set sigma2_init to 1, ignoring default provided - sigma2_init = 1.0 - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf = 1.0 / num_trees_mean if b_leaf is None else b_leaf - if self.has_basis: - if sigma2_leaf is None: - current_leaf_scale = np.zeros( - (self.num_basis, self.num_basis), dtype=float - ) - np.fill_diagonal( - current_leaf_scale, - 2.0 / num_trees_mean, - ) - elif isinstance(sigma2_leaf, float): - current_leaf_scale = np.zeros( - (self.num_basis, self.num_basis), dtype=float - ) - np.fill_diagonal(current_leaf_scale, sigma2_leaf) - elif isinstance(sigma2_leaf, np.ndarray): - if sigma2_leaf.ndim != 2: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != self.num_basis: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" - ) - current_leaf_scale = sigma2_leaf - else: - raise ValueError( - "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" - ) - else: - if sigma2_leaf is None: - current_leaf_scale = np.array([[2.0 / num_trees_mean]]) - elif isinstance(sigma2_leaf, float): - 
current_leaf_scale = np.array([[sigma2_leaf]]) - elif isinstance(sigma2_leaf, np.ndarray): - if sigma2_leaf.ndim != 2: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != 1: - raise ValueError( - "sigma2_leaf must be a 1x1 numpy array for this leaf model" - ) - current_leaf_scale = sigma2_leaf - else: - raise ValueError( - "sigma2_leaf must be either a scalar or a 2d numpy array" - ) - elif link_is_cloglog: - # Fix offset to 0 and scale to 1 - self.y_bar = 0 - self.y_std = 1 - - # Remap outcomes to start from 0 - resid_train = y_train - np.min(unique_outcomes) - cloglog_num_categories = int(np.max(resid_train)) + 1 - - # Set initial values of root nodes to 0.0 (in linear scale) - init_val_mean = 0.0 - - # Calibrate priors for sigma^2 and tau - sigma2_init = 1.0 - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - current_leaf_scale = np.array([[2.0 / num_trees_mean]]) - - # Set first cutpoint to 0 for identifiability - cloglog_cutpoint_0 = 0.0 - - # Set shape and rate parameters for conditional gamma model - cloglog_forest_shape = 2.0 - cloglog_forest_rate = 2.0 - else: - # Standardize if requested - if self.standardize: - self.y_bar = np.squeeze(np.mean(y_train)) - self.y_std = np.squeeze(np.std(y_train)) - else: - self.y_bar = 0 - self.y_std = 1 - - # Compute residual value - resid_train = (y_train - self.y_bar) / self.y_std - - # Compute initial value of root nodes in mean forest - init_val_mean = np.squeeze(np.mean(resid_train)) - - # Calibrate priors for global sigma^2 and sigma2_leaf - if not sigma2_init: - sigma2_init = 1.0 * np.var(resid_train) - if not variance_forest_leaf_init: - variance_forest_leaf_init = 0.6 * np.var(resid_train) - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - if 
self.include_mean_forest: - b_leaf = ( - np.squeeze(np.var(resid_train)) / num_trees_mean - if b_leaf is None - else b_leaf - ) - if self.has_basis: - if sigma2_leaf is None: - current_leaf_scale = np.zeros( - (self.num_basis, self.num_basis), dtype=float - ) - np.fill_diagonal( - current_leaf_scale, - np.squeeze(np.var(resid_train)) / num_trees_mean, - ) - elif isinstance(sigma2_leaf, float): - current_leaf_scale = np.zeros( - (self.num_basis, self.num_basis), dtype=float - ) - np.fill_diagonal(current_leaf_scale, sigma2_leaf) - elif isinstance(sigma2_leaf, np.ndarray): - if sigma2_leaf.ndim != 2: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != self.num_basis: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" - ) - current_leaf_scale = sigma2_leaf - else: - raise ValueError( - "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" - ) - else: - if sigma2_leaf is None: - current_leaf_scale = np.array([ - [np.squeeze(np.var(resid_train)) / num_trees_mean] - ]) - elif isinstance(sigma2_leaf, float): - current_leaf_scale = np.array([[sigma2_leaf]]) - elif isinstance(sigma2_leaf, np.ndarray): - if sigma2_leaf.ndim != 2: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: - raise ValueError( - "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" - ) - if sigma2_leaf.shape[0] != 1: - raise ValueError( - "sigma2_leaf must be a 1x1 numpy array for this leaf model" - ) - current_leaf_scale = sigma2_leaf - else: - raise ValueError( - "sigma2_leaf must be either a scalar or a 2d numpy array" - ) - else: - current_leaf_scale = 
np.array([[1.0]]) - if self.include_variance_forest: - if not a_forest: - a_forest = num_trees_variance / a_0**2 + 0.5 - if not b_forest: - b_forest = num_trees_variance / a_0**2 - else: - if not a_forest: - a_forest = 1.0 - if not b_forest: - b_forest = 1.0 - + # Runtime checks on RFX group ids self.has_rfx = False has_rfx_test = False @@ -1288,825 +1096,1125 @@ def sample( elif self.rfx_model_spec == "intercept_only": if rfx_basis_test is None: rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) - # Set up random effects structures - if self.has_rfx: - # Prior parameters - if rfx_working_parameter_prior_mean is None: - if num_rfx_components == 1: - alpha_init = np.array([0.0], dtype=float) - elif num_rfx_components > 1: - alpha_init = np.zeros(num_rfx_components, dtype=float) - else: - raise ValueError("There must be at least 1 random effect component") - else: - alpha_init = _expand_dims_1d( - rfx_working_parameter_prior_mean, num_rfx_components - ) - - if rfx_group_parameter_prior_mean is None: - xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) - else: - xi_init = _expand_dims_2d( - rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups - ) - - if rfx_working_parameter_prior_cov is None: - sigma_alpha_init = np.identity(num_rfx_components) - else: - sigma_alpha_init = _expand_dims_2d_diag( - rfx_working_parameter_prior_cov, num_rfx_components - ) - - if rfx_group_parameter_prior_cov is None: - sigma_xi_init = np.identity(num_rfx_components) - else: - sigma_xi_init = _expand_dims_2d_diag( - rfx_group_parameter_prior_cov, num_rfx_components - ) - - sigma_xi_shape = rfx_variance_prior_shape - sigma_xi_scale = rfx_variance_prior_scale - - # Random effects sampling data structures - rfx_dataset_train = RandomEffectsDataset() - rfx_dataset_train.add_group_labels(rfx_group_ids_train) - rfx_dataset_train.add_basis(rfx_basis_train) - rfx_tracker = RandomEffectsTracker(rfx_group_ids_train) - rfx_model = 
RandomEffectsModel(num_rfx_components, num_rfx_groups) - rfx_model.set_working_parameter(alpha_init) - rfx_model.set_group_parameters(xi_init) - rfx_model.set_working_parameter_covariance(sigma_alpha_init) - rfx_model.set_group_parameter_covariance(sigma_xi_init) - rfx_model.set_variance_prior_shape(sigma_xi_shape) - rfx_model.set_variance_prior_scale(sigma_xi_scale) - self.rfx_container = RandomEffectsContainer() - self.rfx_container.load_new_container( - num_rfx_components, num_rfx_groups, rfx_tracker - ) - # Container of variance parameter samples - self.num_gfr = num_gfr - self.num_burnin = num_burnin - self.num_mcmc = num_mcmc - self.num_chains = num_chains - self.keep_every = keep_every - num_temp_samples = num_gfr + num_burnin + num_mcmc * keep_every - num_retained_samples = num_mcmc * num_chains - # Delete GFR samples from these containers after the fact if desired - # if keep_gfr: - # num_retained_samples += num_gfr - num_retained_samples += num_gfr - if keep_burnin: - num_retained_samples += num_burnin * num_chains - self.num_samples = num_retained_samples - self.sample_sigma2_global = sample_sigma2_global - self.sample_sigma2_leaf = sample_sigma2_leaf - if sample_sigma2_global: - self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) - if sample_sigma2_leaf: - self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) - if self.include_mean_forest: - yhat_train_raw = np.empty( - (self.n_train, self.num_samples), dtype=np.float64 - ) - if self.include_variance_forest: - sigma2_x_train_raw = np.empty( - (self.n_train, self.num_samples), dtype=np.float64 - ) - sample_counter = -1 - - # Forest Dataset (covariates and optional basis) - forest_dataset_train = Dataset() - forest_dataset_train.add_covariates(X_train_processed) - if self.has_basis: - forest_dataset_train.add_basis(leaf_basis_train) - if observation_weights is not None: - forest_dataset_train.add_variance_weights(observation_weights_) - if self.has_test: - 
forest_dataset_test = Dataset() - forest_dataset_test.add_covariates(X_test_processed) - if self.has_basis: - forest_dataset_test.add_basis(leaf_basis_test) - - # Residual - residual_train = Residual(resid_train) - - # C++ and Numpy random number generator - if random_seed is None: - cpp_rng = RNG(-1) - self.rng = np.random.default_rng() - else: - cpp_rng = RNG(random_seed) - self.rng = np.random.default_rng(random_seed) - - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest = 3 - leaf_dimension_variance = 1 - - # Determine the mean forest leaf model type - if link_is_cloglog and not self.has_basis: - leaf_model_mean_forest = 4 - leaf_dimension_mean = 1 - elif not self.has_basis: - leaf_model_mean_forest = 0 - leaf_dimension_mean = 1 - elif self.num_basis == 1: - leaf_model_mean_forest = 1 - leaf_dimension_mean = 1 + if run_cpp: + bart_config = { + "standardize_outcome": self.standardize, + "num_threads": num_threads, + "cutpoint_grid_size": cutpoint_grid_size, + "link_function": 0 if self.outcome_model.link == "identity" else (1 if self.outcome_model.link == "probit" else 2), + "outcome_type": 0 if self.outcome_model.outcome == "continuous" else (1 if self.outcome_model.outcome == "binary" else 2), + "random_seed": random_seed, + "a_sigma2_global": a_global, + "b_sigma2_global": b_global, + "sigma2_global_init": sigma2_init, + "sample_sigma2_global": sample_sigma2_global, + "num_trees_mean": num_trees_mean, + "alpha_mean": alpha_mean, + "beta_mean": beta_mean, + "min_samples_leaf_mean": min_samples_leaf_mean, + "max_depth_mean": max_depth_mean, + "leaf_constant_mean": True if self.has_basis else False, + "leaf_dim_mean": self.num_basis if self.has_basis else 1, + "exponentiated_leaf_mean": False, + "num_features_subsample_mean": num_features_subsample_mean, + "a_sigma2_mean": a_leaf, + "b_sigma2_mean": b_leaf, + "sigma2_mean_init": sigma2_init, + "sample_sigma2_leaf_mean": sample_sigma2_leaf, + "num_trees_variance": 
num_trees_variance, + "leaf_prior_calibration_param": a_0, + "shape_variance_forest": a_forest, + "scale_variance_forest": b_forest, + "alpha_variance": alpha_variance, + "beta_variance": beta_variance, + "min_samples_leaf_variance": min_samples_leaf_variance, + "max_depth_variance": max_depth_variance, + "leaf_constant_variance": True, + "leaf_dim_variance": 1, + "exponentiated_leaf_variance": True, + "num_features_subsample_variance": num_features_subsample_variance, + "feature_types": feature_types.astype(int), + "sweep_update_indices_mean": list(range(num_trees_mean)) if num_trees_mean > 0 else None, + "sweep_update_indices_variance": list(range(num_trees_variance)) if num_trees_variance > 0 else None, + "var_weights_mean": variable_weights_mean, + "var_weights_variance": variable_weights_variance + } + + bart_results = bart_sample_cpp( + X_train = X_train_processed, + y_train = y_train, + X_test = X_test_processed if self.has_test else None, + n_train = X_train_processed.shape[0], + n_test = X_test_processed.shape[0] if self.has_test else 0, + p = X_train_processed.shape[1], + basis_train = leaf_basis_train if self.has_basis else None, + basis_test = leaf_basis_test if self.has_basis and self.has_test else None, + basis_dim = self.num_basis if self.has_basis else None, + obs_weights_train = observation_weights if observation_weights is not None else None, + obs_weights_test = None, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + rfx_num_groups = num_rfx_groups if self.has_rfx else None, + rfx_basis_dim = self.num_rfx_basis if self.has_rfx else None, + num_gfr = num_gfr, + num_burnin = num_burnin, + keep_every = keep_every, + num_mcmc = num_mcmc, + config_input = bart_config + ) + + self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + 
self.forest_container_mean.forest_container_cpp = bart_results["mean_forests"] + if self.include_variance_forest: + self.forest_container_variance = ForestContainer(num_trees=num_trees_variance, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + self.forest_container_variance.forest_container_cpp = bart_results["variance_forests"] + if sample_sigma2_global: + self.global_var_samples = bart_results["global_error_variance_samples"] * self.y_std * self.y_std + if sample_sigma2_leaf: + self.leaf_scale_samples = bart_results["leaf_scale_samples"] + mean_forest_preds_train = bart_results["mean_forest_predictions_train"] + mean_forest_preds_train = mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar + if self.has_test: + mean_forest_preds_test = bart_results["mean_forest_predictions_test"] + mean_forest_preds_test = mean_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar + if self.include_variance_forest: + variance_forest_preds_train = bart_results["variance_forest_predictions_train"] + variance_forest_preds_train = variance_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + self.variance_forest_preds_train = variance_forest_preds_train * self.y_std * self.y_std + if self.has_test: + variance_forest_preds_test = bart_results["variance_forest_predictions_test"] + variance_forest_preds_test = variance_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + self.variance_forest_preds_test = variance_forest_preds_test * self.y_std * self.y_std + + self.num_samples = bart_results["num_samples"] + self.sampled = True + + return self + + else: - leaf_model_mean_forest = 2 - leaf_dimension_mean = self.num_basis - - # Sampling data structures - global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) - if self.include_mean_forest: - forest_model_config_mean = ForestModelConfig( - 
num_trees=num_trees_mean, - num_features=num_features, - num_observations=self.n_train, - feature_types=feature_types, - variable_weights=variable_weights_mean, - leaf_dimension=leaf_dimension_mean, - alpha=alpha_mean, - beta=beta_mean, - min_samples_leaf=min_samples_leaf_mean, - max_depth=max_depth_mean, - leaf_model_type=leaf_model_mean_forest, - leaf_model_scale=current_leaf_scale, - cutpoint_grid_size=cutpoint_grid_size, - num_features_subsample=num_features_subsample_mean, - ) - if link_is_cloglog: - forest_model_config_mean.update_cloglog_forest_shape(cloglog_forest_shape) - forest_model_config_mean.update_cloglog_forest_rate(cloglog_forest_rate) - forest_sampler_mean = ForestSampler( - forest_dataset_train, - global_model_config, - forest_model_config_mean, - ) - if self.include_variance_forest: - forest_model_config_variance = ForestModelConfig( - num_trees=num_trees_variance, - num_features=num_features, - num_observations=self.n_train, - feature_types=feature_types, - variable_weights=variable_weights_variance, - leaf_dimension=leaf_dimension_variance, - alpha=alpha_variance, - beta=beta_variance, - min_samples_leaf=min_samples_leaf_variance, - max_depth=max_depth_variance, - leaf_model_type=leaf_model_variance_forest, - cutpoint_grid_size=cutpoint_grid_size, - variance_forest_shape=a_forest, - variance_forest_scale=b_forest, - num_features_subsample=num_features_subsample_variance, - ) - forest_sampler_variance = ForestSampler( - forest_dataset_train, - global_model_config, - forest_model_config_variance, - ) - - # Container of forest samples - if self.include_mean_forest: - self.forest_container_mean = ( - ForestContainer(num_trees_mean, 1, True, False) - if not self.has_basis - else ForestContainer(num_trees_mean, self.num_basis, False, False) - ) - active_forest_mean = ( - Forest(num_trees_mean, 1, True, False) - if not self.has_basis - else Forest(num_trees_mean, self.num_basis, False, False) - ) - if self.include_variance_forest: - 
self.forest_container_variance = ForestContainer( - num_trees_variance, 1, True, True - ) - active_forest_variance = Forest(num_trees_variance, 1, True, True) - - # Variance samplers - if self.sample_sigma2_global: - global_var_model = GlobalVarianceModel() - if self.sample_sigma2_leaf: - leaf_var_model = LeafVarianceModel() - - # Initialize the leaves of each tree in the mean forest - if self.include_mean_forest: - if self.has_basis: - init_val_mean = np.repeat(0.0, leaf_basis_train.shape[1]) - else: - init_val_mean = np.array([0.0]) - forest_sampler_mean.prepare_for_sampler( - forest_dataset_train, - residual_train, - active_forest_mean, - leaf_model_mean_forest, - init_val_mean, - ) - - # Initialize the leaves of each tree in the variance forest - if self.include_variance_forest: - init_val_variance = np.array([variance_forest_leaf_init]) - forest_sampler_variance.prepare_for_sampler( - forest_dataset_train, - residual_train, - active_forest_variance, - leaf_model_variance_forest, - init_val_variance, - ) - - # Initialize auxiliary data and ordinal sampler for cloglog - if link_is_cloglog: - ordinal_sampler = OrdinalSampler() - train_size = self.n_train - - # Slot 0: Latent variable Z (size n_train) - forest_dataset_train.add_auxiliary_dimension(train_size) - # Slot 1: Forest predictions eta (size n_train) - forest_dataset_train.add_auxiliary_dimension(train_size) - # Slot 2: Log-scale cutpoints gamma (size cloglog_num_categories - 1) - forest_dataset_train.add_auxiliary_dimension(cloglog_num_categories - 1) - # Slot 3: Cumulative exp cutpoints seg (size cloglog_num_categories) - forest_dataset_train.add_auxiliary_dimension(cloglog_num_categories) - - # Initialize all slots to 0 - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) - forest_dataset_train.set_auxiliary_data_value(1, j, 0.0) - for j in range(cloglog_num_categories - 1): - forest_dataset_train.set_auxiliary_data_value(2, j, 0.0) - - # Compute initial cumulative 
exp sums - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - - # Allocate storage for cutpoint samples - cloglog_cutpoint_samples = np.full( - (cloglog_num_categories - 1, num_retained_samples), np.nan - ) - # Run GFR (warm start) if specified - if self.num_gfr > 0: - for i in range(self.num_gfr): - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample = keep_gfr - keep_sample = True - if keep_sample: - sample_counter += 1 - if self.include_mean_forest: - if link_is_probit: - # Sample latent probit variable z | - - outcome_pred = active_forest_mean.predict(forest_dataset_train) - if self.has_rfx: - rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) - outcome_pred = outcome_pred + rfx_pred - # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back - eta_pred = outcome_pred + self.y_bar - mu0 = eta_pred[y_train[:, 0] == 0] - mu1 = eta_pred[y_train[:, 0] == 1] - n0 = np.sum(y_train[:, 0] == 0) - n1 = np.sum(y_train[:, 0] == 1) - u0 = self.rng.uniform( - low=0.0, - high=norm.cdf(0 - mu0), - size=n0, - ) - u1 = self.rng.uniform( - low=norm.cdf(0 - mu1), - high=1.0, - size=n1, - ) - resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) - resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - - # Update outcome: center z by y_bar before passing to forest - new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred - residual_train.update_data(new_outcome) - - # Sample the mean forest - forest_sampler_mean.sample_one_iteration( - self.forest_container_mean, - active_forest_mean, - forest_dataset_train, - residual_train, - cpp_rng, - global_model_config, - forest_model_config_mean, - keep_sample, - True, - num_threads, - ) - - # Cache train set predictions since they are already computed during sampling - if keep_sample: - yhat_train_raw[:, sample_counter] = ( - forest_sampler_mean.get_cached_forest_predictions() - ) - - # Sample the variance forest - if 
self.include_variance_forest: - forest_sampler_variance.sample_one_iteration( - self.forest_container_variance, - active_forest_variance, - forest_dataset_train, - residual_train, - cpp_rng, - global_model_config, - forest_model_config_variance, - keep_sample, - True, - num_threads, - ) - - # Cache train set predictions since they are already computed during sampling - if keep_sample: - sigma2_x_train_raw[:, sample_counter] = ( - forest_sampler_variance.get_cached_forest_predictions() - ) - - # Sample variance parameters (if requested) - if self.sample_sigma2_global: - current_sigma2 = global_var_model.sample_one_iteration( - residual_train, cpp_rng, a_global, b_global - ) - global_model_config.update_global_error_variance(current_sigma2) - if keep_sample: - self.global_var_samples[sample_counter] = current_sigma2 - if self.sample_sigma2_leaf: - current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( - active_forest_mean, cpp_rng, a_leaf, b_leaf - ) - forest_model_config_mean.update_leaf_model_scale(current_leaf_scale) - if keep_sample: - self.leaf_scale_samples[sample_counter] = current_leaf_scale[ - 0, 0 - ] - - # Sample random effects - if self.has_rfx: - rfx_model.sample( - rfx_dataset_train, - residual_train, - rfx_tracker, - self.rfx_container, - keep_sample, - current_sigma2, - cpp_rng, - ) - - # Cloglog Gibbs updates - if link_is_cloglog: - # Update auxiliary data slot 1 with current forest predictions - forest_pred_current = forest_sampler_mean.get_cached_forest_predictions() - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(1, j, forest_pred_current[j]) - - # Sample latent z_i's using truncated exponential - ordinal_sampler.update_latent_variables( - forest_dataset_train, residual_train, cpp_rng - ) - - # Sample gamma parameters (cutpoints) - ordinal_sampler.update_gamma_params( - forest_dataset_train, - residual_train, - cloglog_forest_shape, - cloglog_forest_rate, - cloglog_cutpoint_0, - cpp_rng, - ) - - # Update 
cumulative sum of exp(gamma) values - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - - # Retain cutpoint draw - if keep_sample: - cloglog_cutpoint_samples[:, sample_counter] = ( - forest_dataset_train.get_auxiliary_data_vector(2) - ) - - # Run MCMC - if self.num_burnin + self.num_mcmc > 0: - for chain_num in range(num_chains): - if num_gfr > 0: - forest_ind = num_gfr - chain_num - 1 - # Reset mean forest - if self.include_mean_forest: - active_forest_mean.reset(self.forest_container_mean, forest_ind) - forest_sampler_mean.reconstitute_from_forest( - active_forest_mean, - forest_dataset_train, - residual_train, - True, - ) - # Reset leaf scale - if sample_sigma2_leaf: - leaf_scale_double = self.leaf_scale_samples[ - forest_ind - ] - current_leaf_scale[0, 0] = leaf_scale_double - forest_model_config_mean.update_leaf_model_scale( - leaf_scale_double - ) - # Reset variance forest - if self.include_variance_forest: - active_forest_variance.reset( - self.forest_container_variance, forest_ind - ) - forest_sampler_variance.reconstitute_from_forest( - active_forest_variance, - forest_dataset_train, - residual_train, - False, - ) - # Reset global error scale - if sample_sigma2_global: - current_sigma2 = self.global_var_samples[forest_ind] - global_model_config.update_global_error_variance(current_sigma2) - # Reset random effects - if self.has_rfx: - rfx_model.reset(self.rfx_container, forest_ind, sigma_alpha_init) - rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container) - # Reset cloglog auxiliary data - if link_is_cloglog: - # Reset cutpoints from saved GFR samples - current_cutpoints = cloglog_cutpoint_samples[:, forest_ind] - for j in range(len(current_cutpoints)): - forest_dataset_train.set_auxiliary_data_value(2, j, current_cutpoints[j]) - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - # Reset forest predictions by re-predicting from active forest - active_forest_preds = 
active_forest_mean.predict(forest_dataset_train) - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(1, j, active_forest_preds[j]) - # Latent variables must be reset to 0 and burnt in - forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) - elif has_prev_model: - warmstart_index = previous_model_warmstart_sample_num - chain_num if previous_model_decrement else previous_model_warmstart_sample_num - # Reset mean forest - if self.include_mean_forest: - active_forest_mean.reset( - previous_bart_model.forest_container_mean, - warmstart_index, - ) - forest_sampler_mean.reconstitute_from_forest( - active_forest_mean, - forest_dataset_train, - residual_train, - True, - ) - # Reset leaf scale - if sample_sigma2_leaf and previous_leaf_var_samples is not None: - leaf_scale_double = previous_leaf_var_samples[ - warmstart_index - ] - current_leaf_scale[0, 0] = leaf_scale_double - forest_model_config_mean.update_leaf_model_scale( - leaf_scale_double - ) - # Reset variance forest - if self.include_variance_forest: - active_forest_variance.reset( - previous_bart_model.forest_container_variance, - warmstart_index, - ) - forest_sampler_variance.reconstitute_from_forest( - active_forest_variance, - forest_dataset_train, - residual_train, - True, - ) - # Reset global error scale - if self.sample_sigma2_global: - current_sigma2 = previous_global_var_samples[ - warmstart_index - ] - global_model_config.update_global_error_variance(current_sigma2) - # Reset random effects - if self.has_rfx: - rfx_model.reset(previous_bart_model.rfx_container, warmstart_index, sigma_alpha_init) - rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, previous_bart_model.rfx_container) - # Reset cloglog auxiliary data from previous model - if link_is_cloglog: - previous_cloglog_cutpoint_samples = getattr( - previous_bart_model, "cloglog_cutpoint_samples", None - ) - if previous_cloglog_cutpoint_samples is not None: - current_cutpoints = 
previous_cloglog_cutpoint_samples[:, warmstart_index] - for j in range(len(current_cutpoints)): - forest_dataset_train.set_auxiliary_data_value(2, j, current_cutpoints[j]) - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - active_forest_preds = active_forest_mean.predict(forest_dataset_train) - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(1, j, active_forest_preds[j]) - # Latent variables must be reset to 0 and burnt in - forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) - else: - # Reset mean forest - if self.include_mean_forest: - active_forest_mean.reset_root() - if init_val_mean.shape[0] == 1: - active_forest_mean.set_root_leaves( - init_val_mean[0] / num_trees_mean - ) - else: - active_forest_mean.set_root_leaves( - init_val_mean / num_trees_mean - ) - forest_sampler_mean.reconstitute_from_forest( - active_forest_mean, - forest_dataset_train, - residual_train, - True, - ) - # Reset mean forest leaf scale - if sample_sigma2_leaf and previous_leaf_var_samples is not None: - current_leaf_scale[0, 0] = sigma2_leaf - forest_model_config_mean.update_leaf_model_scale( - current_leaf_scale - ) - if link_is_cloglog: - # Reset all cloglog parameters to default values - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(1, j, 0.0) - forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) - # Initialize log-scale cutpoints to 0 - initial_gamma = np.zeros(cloglog_num_categories - 1) - for j in range(cloglog_num_categories - 1): - forest_dataset_train.set_auxiliary_data_value( - 2, - j, - initial_gamma[j] - ) - # Convert to cumulative exponentiated cutpoints - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - # Reset variance forest - if self.include_variance_forest: - active_forest_variance.reset_root() - active_forest_variance.set_root_leaves( - log(variance_forest_leaf_init) / num_trees_variance - ) - forest_sampler_variance.reconstitute_from_forest( - 
active_forest_variance, - forest_dataset_train, - residual_train, - False, - ) - # Reset global error scale - if self.sample_sigma2_global: - current_sigma2 = sigma2_init - global_model_config.update_global_error_variance(current_sigma2) - # Reset random effects terms - if self.has_rfx: - rfx_model.root_reset(alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - rfx_tracker.root_reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container) - # Sample MCMC and burnin for each chain - for i in range(self.num_gfr, num_temp_samples): - is_mcmc = i + 1 > num_gfr + num_burnin - if is_mcmc: - mcmc_counter = i - num_gfr - num_burnin + 1 - if mcmc_counter % keep_every == 0: - keep_sample = True - else: - keep_sample = False - else: - if keep_burnin: - keep_sample = True - else: - keep_sample = False - if keep_sample: - sample_counter += 1 - - if self.include_mean_forest: - if link_is_probit: - # Sample latent probit variable z | - - outcome_pred = active_forest_mean.predict( - forest_dataset_train - ) - if self.has_rfx: - rfx_pred = rfx_model.predict( - rfx_dataset_train, rfx_tracker - ) - outcome_pred = outcome_pred + rfx_pred - # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back - eta_pred = outcome_pred + self.y_bar - mu0 = eta_pred[y_train[:, 0] == 0] - mu1 = eta_pred[y_train[:, 0] == 1] - n0 = np.sum(y_train[:, 0] == 0) - n1 = np.sum(y_train[:, 0] == 1) - u0 = self.rng.uniform( - low=0.0, - high=norm.cdf(0 - mu0), - size=n0, - ) - u1 = self.rng.uniform( - low=norm.cdf(0 - mu1), - high=1.0, - size=n1, - ) - resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) - resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - - # Update outcome: center z by y_bar before passing to forest - new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred - residual_train.update_data(new_outcome) - - # Sample the mean forest - forest_sampler_mean.sample_one_iteration( - self.forest_container_mean, - 
active_forest_mean, - forest_dataset_train, - residual_train, - cpp_rng, - global_model_config, - forest_model_config_mean, - keep_sample, - False, - num_threads, - ) - - if keep_sample: - yhat_train_raw[:, sample_counter] = ( - forest_sampler_mean.get_cached_forest_predictions() - ) - - # Sample the variance forest - if self.include_variance_forest: - forest_sampler_variance.sample_one_iteration( - self.forest_container_variance, - active_forest_variance, - forest_dataset_train, - residual_train, - cpp_rng, - global_model_config, - forest_model_config_variance, - keep_sample, - False, - num_threads, - ) - - if keep_sample: - sigma2_x_train_raw[:, sample_counter] = ( - forest_sampler_variance.get_cached_forest_predictions() - ) - - # Sample variance parameters (if requested) - if self.sample_sigma2_global: - current_sigma2 = global_var_model.sample_one_iteration( - residual_train, cpp_rng, a_global, b_global - ) - global_model_config.update_global_error_variance(current_sigma2) - if keep_sample: - self.global_var_samples[sample_counter] = current_sigma2 - if self.sample_sigma2_leaf: - current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( - active_forest_mean, cpp_rng, a_leaf, b_leaf - ) - forest_model_config_mean.update_leaf_model_scale( - current_leaf_scale - ) - if keep_sample: - self.leaf_scale_samples[sample_counter] = ( - current_leaf_scale[0, 0] - ) - - # Sample random effects - if self.has_rfx: - rfx_model.sample( - rfx_dataset_train, - residual_train, - rfx_tracker, - self.rfx_container, - keep_sample, - current_sigma2, - cpp_rng, - ) - - # Cloglog Gibbs updates - if link_is_cloglog: - # Update auxiliary data slot 1 with current forest predictions - forest_pred_current = forest_sampler_mean.get_cached_forest_predictions() - for j in range(train_size): - forest_dataset_train.set_auxiliary_data_value(1, j, forest_pred_current[j]) - - # Sample latent z_i's using truncated exponential - ordinal_sampler.update_latent_variables( - forest_dataset_train, 
residual_train, cpp_rng - ) - - # Sample gamma parameters (cutpoints) - ordinal_sampler.update_gamma_params( - forest_dataset_train, - residual_train, - cloglog_forest_shape, - cloglog_forest_rate, - cloglog_cutpoint_0, - cpp_rng, - ) - - # Update cumulative sum of exp(gamma) values - ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) - - # Retain cutpoint draw - if keep_sample: - cloglog_cutpoint_samples[:, sample_counter] = ( - forest_dataset_train.get_auxiliary_data_vector(2) - ) - - # Mark the model as sampled - self.sampled = True - - # Remove GFR samples if they are not to be retained - if not keep_gfr and num_gfr > 0: - for i in range(num_gfr): - if self.include_mean_forest: - self.forest_container_mean.delete_sample(0) - if self.include_variance_forest: - self.forest_container_variance.delete_sample(0) - if self.has_rfx: - self.rfx_container.delete_sample(0) - if self.sample_sigma2_global: - self.global_var_samples = self.global_var_samples[num_gfr:] - if self.sample_sigma2_leaf: - self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] - if self.include_mean_forest: - yhat_train_raw = yhat_train_raw[:, num_gfr:] - if self.include_variance_forest: - sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] - if link_is_cloglog: - cloglog_cutpoint_samples = cloglog_cutpoint_samples[:, num_gfr:] - self.num_samples -= num_gfr - - # Store cloglog results (cutpoints only for ordinal, num_categories always) - if link_is_cloglog: - self.cloglog_num_categories = cloglog_num_categories - if not outcome_is_binary: - self.cloglog_cutpoint_samples = cloglog_cutpoint_samples - - # Store predictions - if self.sample_sigma2_global: - self.global_var_samples = self.global_var_samples * self.y_std * self.y_std - - if self.sample_sigma2_leaf: - self.leaf_scale_samples = self.leaf_scale_samples - - if self.include_mean_forest: - self.y_hat_train = yhat_train_raw * self.y_std + self.y_bar - if self.has_test: - yhat_test_raw = 
self.forest_container_mean.forest_container_cpp.Predict( - forest_dataset_test.dataset_cpp - ) - self.y_hat_test = yhat_test_raw * self.y_std + self.y_bar - - # TODO: make rfx_preds_train and rfx_preds_test persistent properties - if self.has_rfx: - rfx_preds_train = ( - self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) - * self.y_std - ) - if has_rfx_test: - rfx_preds_test = ( - self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) - * self.y_std - ) - if self.include_mean_forest: - self.y_hat_train = self.y_hat_train + rfx_preds_train - if self.has_test: - self.y_hat_test = self.y_hat_test + rfx_preds_test - else: - self.y_hat_train = rfx_preds_train - if self.has_test: - self.y_hat_test = rfx_preds_test - - if self.include_variance_forest: - if self.sample_sigma2_global: - self.sigma2_x_train = np.empty_like(sigma2_x_train_raw) - for i in range(self.num_samples): - self.sigma2_x_train[:, i] = ( - np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i] - ) - else: - self.sigma2_x_train = ( - np.exp(sigma2_x_train_raw) - * self.sigma2_init - * self.y_std - * self.y_std - ) - if self.has_test: - sigma2_x_test_raw = ( - self.forest_container_variance.forest_container_cpp.Predict( - forest_dataset_test.dataset_cpp - ) - ) - if self.sample_sigma2_global: - self.sigma2_x_test = sigma2_x_test_raw - for i in range(self.num_samples): - self.sigma2_x_test[:, i] = ( - sigma2_x_test_raw[:, i] * self.global_var_samples[i] - ) - else: - self.sigma2_x_test = ( - sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std - ) + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if link_is_probit: + # Compute a probit-scale offset and fix scale to 1 + self.y_bar = norm.ppf(np.squeeze(np.mean(y_train))) + self.y_std = 1.0 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train = y_train - np.squeeze(np.mean(y_train)) + + # Set initial values of 
root nodes to 0.0 (in probit scale) + init_val_mean = 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init = 1.0 + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf = 1.0 / num_trees_mean if b_leaf is None else b_leaf + if self.has_basis: + if sigma2_leaf is None: + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) + np.fill_diagonal( + current_leaf_scale, + 2.0 / num_trees_mean, + ) + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) + np.fill_diagonal(current_leaf_scale, sigma2_leaf) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != self.num_basis: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) + current_leaf_scale = sigma2_leaf + else: + raise ValueError( + "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" + ) + else: + if sigma2_leaf is None: + current_leaf_scale = np.array([[2.0 / num_trees_mean]]) + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.array([[sigma2_leaf]]) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != 1: + raise ValueError( + "sigma2_leaf must be a 1x1 numpy 
array for this leaf model" + ) + current_leaf_scale = sigma2_leaf + else: + raise ValueError( + "sigma2_leaf must be either a scalar or a 2d numpy array" + ) + elif link_is_cloglog: + # Fix offset to 0 and scale to 1 + self.y_bar = 0 + self.y_std = 1 + + # Remap outcomes to start from 0 + resid_train = y_train - np.min(unique_outcomes) + cloglog_num_categories = int(np.max(resid_train)) + 1 + + # Set initial values of root nodes to 0.0 (in linear scale) + init_val_mean = 0.0 + + # Calibrate priors for sigma^2 and tau + sigma2_init = 1.0 + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + current_leaf_scale = np.array([[2.0 / num_trees_mean]]) + + # Set first cutpoint to 0 for identifiability + cloglog_cutpoint_0 = 0.0 + + # Set shape and rate parameters for conditional gamma model + cloglog_forest_shape = 2.0 + cloglog_forest_rate = 2.0 + else: + # Standardize if requested + if self.standardize: + self.y_bar = np.squeeze(np.mean(y_train)) + self.y_std = np.squeeze(np.std(y_train)) + else: + self.y_bar = 0 + self.y_std = 1 + + # Compute residual value + resid_train = (y_train - self.y_bar) / self.y_std + + # Compute initial value of root nodes in mean forest + init_val_mean = np.squeeze(np.mean(resid_train)) + + # Calibrate priors for global sigma^2 and sigma2_leaf + if not sigma2_init: + sigma2_init = 1.0 * np.var(resid_train) + if not variance_forest_leaf_init: + variance_forest_leaf_init = 0.6 * np.var(resid_train) + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + if self.include_mean_forest: + b_leaf = ( + np.squeeze(np.var(resid_train)) / num_trees_mean + if b_leaf is None + else b_leaf + ) + if self.has_basis: + if sigma2_leaf is None: + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) + np.fill_diagonal( + current_leaf_scale, + np.squeeze(np.var(resid_train)) / num_trees_mean, + ) + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), 
dtype=float + ) + np.fill_diagonal(current_leaf_scale, sigma2_leaf) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != self.num_basis: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) + current_leaf_scale = sigma2_leaf + else: + raise ValueError( + "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" + ) + else: + if sigma2_leaf is None: + current_leaf_scale = np.array([ + [np.squeeze(np.var(resid_train)) / num_trees_mean] + ]) + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.array([[sigma2_leaf]]) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: + raise ValueError( + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma2_leaf.shape[0] != 1: + raise ValueError( + "sigma2_leaf must be a 1x1 numpy array for this leaf model" + ) + current_leaf_scale = sigma2_leaf + else: + raise ValueError( + "sigma2_leaf must be either a scalar or a 2d numpy array" + ) + else: + current_leaf_scale = np.array([[1.0]]) + if self.include_variance_forest: + if not a_forest: + a_forest = num_trees_variance / a_0**2 + 0.5 + if not b_forest: + b_forest = num_trees_variance / a_0**2 + else: + if not a_forest: + a_forest = 1.0 + if not b_forest: + b_forest = 1.0 + + # Set up random effects structures + if self.has_rfx: + # Prior parameters + if rfx_working_parameter_prior_mean is None: + if num_rfx_components == 1: + alpha_init = np.array([0.0], dtype=float) + elif 
num_rfx_components > 1: + alpha_init = np.zeros(num_rfx_components, dtype=float) + else: + raise ValueError("There must be at least 1 random effect component") + else: + alpha_init = _expand_dims_1d( + rfx_working_parameter_prior_mean, num_rfx_components + ) + + if rfx_group_parameter_prior_mean is None: + xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) + else: + xi_init = _expand_dims_2d( + rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups + ) + + if rfx_working_parameter_prior_cov is None: + sigma_alpha_init = np.identity(num_rfx_components) + else: + sigma_alpha_init = _expand_dims_2d_diag( + rfx_working_parameter_prior_cov, num_rfx_components + ) + + if rfx_group_parameter_prior_cov is None: + sigma_xi_init = np.identity(num_rfx_components) + else: + sigma_xi_init = _expand_dims_2d_diag( + rfx_group_parameter_prior_cov, num_rfx_components + ) + + sigma_xi_shape = rfx_variance_prior_shape + sigma_xi_scale = rfx_variance_prior_scale + + # Random effects sampling data structures + rfx_dataset_train = RandomEffectsDataset() + rfx_dataset_train.add_group_labels(rfx_group_ids_train) + rfx_dataset_train.add_basis(rfx_basis_train) + rfx_tracker = RandomEffectsTracker(rfx_group_ids_train) + rfx_model = RandomEffectsModel(num_rfx_components, num_rfx_groups) + rfx_model.set_working_parameter(alpha_init) + rfx_model.set_group_parameters(xi_init) + rfx_model.set_working_parameter_covariance(sigma_alpha_init) + rfx_model.set_group_parameter_covariance(sigma_xi_init) + rfx_model.set_variance_prior_shape(sigma_xi_shape) + rfx_model.set_variance_prior_scale(sigma_xi_scale) + self.rfx_container = RandomEffectsContainer() + self.rfx_container.load_new_container( + num_rfx_components, num_rfx_groups, rfx_tracker + ) + + # Container of variance parameter samples + self.num_gfr = num_gfr + self.num_burnin = num_burnin + self.num_mcmc = num_mcmc + self.num_chains = num_chains + self.keep_every = keep_every + num_temp_samples = num_gfr + 
num_burnin + num_mcmc * keep_every + num_retained_samples = num_mcmc * num_chains + # Delete GFR samples from these containers after the fact if desired + # if keep_gfr: + # num_retained_samples += num_gfr + num_retained_samples += num_gfr + if keep_burnin: + num_retained_samples += num_burnin * num_chains + self.num_samples = num_retained_samples + self.sample_sigma2_global = sample_sigma2_global + self.sample_sigma2_leaf = sample_sigma2_leaf + if sample_sigma2_global: + self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) + if sample_sigma2_leaf: + self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) + if self.include_mean_forest: + yhat_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) + if self.include_variance_forest: + sigma2_x_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) + sample_counter = -1 + + # Forest Dataset (covariates and optional basis) + forest_dataset_train = Dataset() + forest_dataset_train.add_covariates(X_train_processed) + if self.has_basis: + forest_dataset_train.add_basis(leaf_basis_train) + if observation_weights is not None: + forest_dataset_train.add_variance_weights(observation_weights_) + if self.has_test: + forest_dataset_test = Dataset() + forest_dataset_test.add_covariates(X_test_processed) + if self.has_basis: + forest_dataset_test.add_basis(leaf_basis_test) + + # Residual + residual_train = Residual(resid_train) + + # C++ and Numpy random number generator + if random_seed is None: + cpp_rng = RNG(-1) + self.rng = np.random.default_rng() + else: + cpp_rng = RNG(random_seed) + self.rng = np.random.default_rng(random_seed) + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest = 3 + leaf_dimension_variance = 1 + + # Determine the mean forest leaf model type + if link_is_cloglog and not self.has_basis: + leaf_model_mean_forest = 4 + leaf_dimension_mean = 1 + elif not self.has_basis: + 
leaf_model_mean_forest = 0 + leaf_dimension_mean = 1 + elif self.num_basis == 1: + leaf_model_mean_forest = 1 + leaf_dimension_mean = 1 + else: + leaf_model_mean_forest = 2 + leaf_dimension_mean = self.num_basis + + # Sampling data structures + global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) + if self.include_mean_forest: + forest_model_config_mean = ForestModelConfig( + num_trees=num_trees_mean, + num_features=num_features, + num_observations=self.n_train, + feature_types=feature_types, + variable_weights=variable_weights_mean, + leaf_dimension=leaf_dimension_mean, + alpha=alpha_mean, + beta=beta_mean, + min_samples_leaf=min_samples_leaf_mean, + max_depth=max_depth_mean, + leaf_model_type=leaf_model_mean_forest, + leaf_model_scale=current_leaf_scale, + cutpoint_grid_size=cutpoint_grid_size, + num_features_subsample=num_features_subsample_mean, + ) + if link_is_cloglog: + forest_model_config_mean.update_cloglog_forest_shape(cloglog_forest_shape) + forest_model_config_mean.update_cloglog_forest_rate(cloglog_forest_rate) + forest_sampler_mean = ForestSampler( + forest_dataset_train, + global_model_config, + forest_model_config_mean, + ) + if self.include_variance_forest: + forest_model_config_variance = ForestModelConfig( + num_trees=num_trees_variance, + num_features=num_features, + num_observations=self.n_train, + feature_types=feature_types, + variable_weights=variable_weights_variance, + leaf_dimension=leaf_dimension_variance, + alpha=alpha_variance, + beta=beta_variance, + min_samples_leaf=min_samples_leaf_variance, + max_depth=max_depth_variance, + leaf_model_type=leaf_model_variance_forest, + cutpoint_grid_size=cutpoint_grid_size, + variance_forest_shape=a_forest, + variance_forest_scale=b_forest, + num_features_subsample=num_features_subsample_variance, + ) + forest_sampler_variance = ForestSampler( + forest_dataset_train, + global_model_config, + forest_model_config_variance, + ) + + # Container of forest samples + if 
self.include_mean_forest: + self.forest_container_mean = ( + ForestContainer(num_trees_mean, 1, True, False) + if not self.has_basis + else ForestContainer(num_trees_mean, self.num_basis, False, False) + ) + active_forest_mean = ( + Forest(num_trees_mean, 1, True, False) + if not self.has_basis + else Forest(num_trees_mean, self.num_basis, False, False) + ) + if self.include_variance_forest: + self.forest_container_variance = ForestContainer( + num_trees_variance, 1, True, True + ) + active_forest_variance = Forest(num_trees_variance, 1, True, True) + + # Variance samplers + if self.sample_sigma2_global: + global_var_model = GlobalVarianceModel() + if self.sample_sigma2_leaf: + leaf_var_model = LeafVarianceModel() + + # Initialize the leaves of each tree in the mean forest + if self.include_mean_forest: + if self.has_basis: + init_val_mean = np.repeat(0.0, leaf_basis_train.shape[1]) + else: + init_val_mean = np.array([0.0]) + forest_sampler_mean.prepare_for_sampler( + forest_dataset_train, + residual_train, + active_forest_mean, + leaf_model_mean_forest, + init_val_mean, + ) + + # Initialize the leaves of each tree in the variance forest + if self.include_variance_forest: + init_val_variance = np.array([variance_forest_leaf_init]) + forest_sampler_variance.prepare_for_sampler( + forest_dataset_train, + residual_train, + active_forest_variance, + leaf_model_variance_forest, + init_val_variance, + ) + + # Initialize auxiliary data and ordinal sampler for cloglog + if link_is_cloglog: + ordinal_sampler = OrdinalSampler() + train_size = self.n_train + + # Slot 0: Latent variable Z (size n_train) + forest_dataset_train.add_auxiliary_dimension(train_size) + # Slot 1: Forest predictions eta (size n_train) + forest_dataset_train.add_auxiliary_dimension(train_size) + # Slot 2: Log-scale cutpoints gamma (size cloglog_num_categories - 1) + forest_dataset_train.add_auxiliary_dimension(cloglog_num_categories - 1) + # Slot 3: Cumulative exp cutpoints seg (size 
cloglog_num_categories) + forest_dataset_train.add_auxiliary_dimension(cloglog_num_categories) + + # Initialize all slots to 0 + for j in range(train_size): + forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) + forest_dataset_train.set_auxiliary_data_value(1, j, 0.0) + for j in range(cloglog_num_categories - 1): + forest_dataset_train.set_auxiliary_data_value(2, j, 0.0) + + # Compute initial cumulative exp sums + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + + # Allocate storage for cutpoint samples + cloglog_cutpoint_samples = np.full( + (cloglog_num_categories - 1, num_retained_samples), np.nan + ) + # Run GFR (warm start) if specified + if self.num_gfr > 0: + for i in range(self.num_gfr): + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample = keep_gfr + keep_sample = True + if keep_sample: + sample_counter += 1 + if self.include_mean_forest: + if link_is_probit: + # Sample latent probit variable z | - + outcome_pred = active_forest_mean.predict(forest_dataset_train) + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) + + # Update outcome: center z by y_bar before passing to forest + new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred + residual_train.update_data(new_outcome) + + # Sample the mean forest + forest_sampler_mean.sample_one_iteration( + self.forest_container_mean, + 
active_forest_mean, + forest_dataset_train, + residual_train, + cpp_rng, + global_model_config, + forest_model_config_mean, + keep_sample, + True, + num_threads, + ) + + # Cache train set predictions since they are already computed during sampling + if keep_sample: + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) + + # Sample the variance forest + if self.include_variance_forest: + forest_sampler_variance.sample_one_iteration( + self.forest_container_variance, + active_forest_variance, + forest_dataset_train, + residual_train, + cpp_rng, + global_model_config, + forest_model_config_variance, + keep_sample, + True, + num_threads, + ) + + # Cache train set predictions since they are already computed during sampling + if keep_sample: + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) + + # Sample variance parameters (if requested) + if self.sample_sigma2_global: + current_sigma2 = global_var_model.sample_one_iteration( + residual_train, cpp_rng, a_global, b_global + ) + global_model_config.update_global_error_variance(current_sigma2) + if keep_sample: + self.global_var_samples[sample_counter] = current_sigma2 + if self.sample_sigma2_leaf: + current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( + active_forest_mean, cpp_rng, a_leaf, b_leaf + ) + forest_model_config_mean.update_leaf_model_scale(current_leaf_scale) + if keep_sample: + self.leaf_scale_samples[sample_counter] = current_leaf_scale[ + 0, 0 + ] + + # Sample random effects + if self.has_rfx: + rfx_model.sample( + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, + ) + + # Cloglog Gibbs updates + if link_is_cloglog: + # Update auxiliary data slot 1 with current forest predictions + forest_pred_current = forest_sampler_mean.get_cached_forest_predictions() + for j in range(train_size): + 
forest_dataset_train.set_auxiliary_data_value(1, j, forest_pred_current[j]) + + # Sample latent z_i's using truncated exponential + ordinal_sampler.update_latent_variables( + forest_dataset_train, residual_train, cpp_rng + ) + + # Sample gamma parameters (cutpoints) + ordinal_sampler.update_gamma_params( + forest_dataset_train, + residual_train, + cloglog_forest_shape, + cloglog_forest_rate, + cloglog_cutpoint_0, + cpp_rng, + ) + + # Update cumulative sum of exp(gamma) values + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + + # Retain cutpoint draw + if keep_sample: + cloglog_cutpoint_samples[:, sample_counter] = ( + forest_dataset_train.get_auxiliary_data_vector(2) + ) + + # Run MCMC + if self.num_burnin + self.num_mcmc > 0: + for chain_num in range(num_chains): + if num_gfr > 0: + forest_ind = num_gfr - chain_num - 1 + # Reset mean forest + if self.include_mean_forest: + active_forest_mean.reset(self.forest_container_mean, forest_ind) + forest_sampler_mean.reconstitute_from_forest( + active_forest_mean, + forest_dataset_train, + residual_train, + True, + ) + # Reset leaf scale + if sample_sigma2_leaf: + leaf_scale_double = self.leaf_scale_samples[ + forest_ind + ] + current_leaf_scale[0, 0] = leaf_scale_double + forest_model_config_mean.update_leaf_model_scale( + leaf_scale_double + ) + # Reset variance forest + if self.include_variance_forest: + active_forest_variance.reset( + self.forest_container_variance, forest_ind + ) + forest_sampler_variance.reconstitute_from_forest( + active_forest_variance, + forest_dataset_train, + residual_train, + False, + ) + # Reset global error scale + if sample_sigma2_global: + current_sigma2 = self.global_var_samples[forest_ind] + global_model_config.update_global_error_variance(current_sigma2) + # Reset random effects + if self.has_rfx: + rfx_model.reset(self.rfx_container, forest_ind, sigma_alpha_init) + rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container) + # Reset cloglog 
auxiliary data + if link_is_cloglog: + # Reset cutpoints from saved GFR samples + current_cutpoints = cloglog_cutpoint_samples[:, forest_ind] + for j in range(len(current_cutpoints)): + forest_dataset_train.set_auxiliary_data_value(2, j, current_cutpoints[j]) + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + # Reset forest predictions by re-predicting from active forest + active_forest_preds = active_forest_mean.predict(forest_dataset_train) + for j in range(train_size): + forest_dataset_train.set_auxiliary_data_value(1, j, active_forest_preds[j]) + # Latent variables must be reset to 0 and burnt in + forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) + elif has_prev_model: + warmstart_index = previous_model_warmstart_sample_num - chain_num if previous_model_decrement else previous_model_warmstart_sample_num + # Reset mean forest + if self.include_mean_forest: + active_forest_mean.reset( + previous_bart_model.forest_container_mean, + warmstart_index, + ) + forest_sampler_mean.reconstitute_from_forest( + active_forest_mean, + forest_dataset_train, + residual_train, + True, + ) + # Reset leaf scale + if sample_sigma2_leaf and previous_leaf_var_samples is not None: + leaf_scale_double = previous_leaf_var_samples[ + warmstart_index + ] + current_leaf_scale[0, 0] = leaf_scale_double + forest_model_config_mean.update_leaf_model_scale( + leaf_scale_double + ) + # Reset variance forest + if self.include_variance_forest: + active_forest_variance.reset( + previous_bart_model.forest_container_variance, + warmstart_index, + ) + forest_sampler_variance.reconstitute_from_forest( + active_forest_variance, + forest_dataset_train, + residual_train, + True, + ) + # Reset global error scale + if self.sample_sigma2_global: + current_sigma2 = previous_global_var_samples[ + warmstart_index + ] + global_model_config.update_global_error_variance(current_sigma2) + # Reset random effects + if self.has_rfx: + rfx_model.reset(previous_bart_model.rfx_container, 
warmstart_index, sigma_alpha_init) + rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, previous_bart_model.rfx_container) + # Reset cloglog auxiliary data from previous model + if link_is_cloglog: + previous_cloglog_cutpoint_samples = getattr( + previous_bart_model, "cloglog_cutpoint_samples", None + ) + if previous_cloglog_cutpoint_samples is not None: + current_cutpoints = previous_cloglog_cutpoint_samples[:, warmstart_index] + for j in range(len(current_cutpoints)): + forest_dataset_train.set_auxiliary_data_value(2, j, current_cutpoints[j]) + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + active_forest_preds = active_forest_mean.predict(forest_dataset_train) + for j in range(train_size): + forest_dataset_train.set_auxiliary_data_value(1, j, active_forest_preds[j]) + # Latent variables must be reset to 0 and burnt in + forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) + else: + # Reset mean forest + if self.include_mean_forest: + active_forest_mean.reset_root() + if init_val_mean.shape[0] == 1: + active_forest_mean.set_root_leaves( + init_val_mean[0] / num_trees_mean + ) + else: + active_forest_mean.set_root_leaves( + init_val_mean / num_trees_mean + ) + forest_sampler_mean.reconstitute_from_forest( + active_forest_mean, + forest_dataset_train, + residual_train, + True, + ) + # Reset mean forest leaf scale + if sample_sigma2_leaf and previous_leaf_var_samples is not None: + current_leaf_scale[0, 0] = sigma2_leaf + forest_model_config_mean.update_leaf_model_scale( + current_leaf_scale + ) + if link_is_cloglog: + # Reset all cloglog parameters to default values + for j in range(train_size): + forest_dataset_train.set_auxiliary_data_value(1, j, 0.0) + forest_dataset_train.set_auxiliary_data_value(0, j, 0.0) + # Initialize log-scale cutpoints to 0 + initial_gamma = np.zeros(cloglog_num_categories - 1) + for j in range(cloglog_num_categories - 1): + forest_dataset_train.set_auxiliary_data_value( + 2, + j, + initial_gamma[j] + ) 
+ # Convert to cumulative exponentiated cutpoints + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + # Reset variance forest + if self.include_variance_forest: + active_forest_variance.reset_root() + active_forest_variance.set_root_leaves( + log(variance_forest_leaf_init) / num_trees_variance + ) + forest_sampler_variance.reconstitute_from_forest( + active_forest_variance, + forest_dataset_train, + residual_train, + False, + ) + # Reset global error scale + if self.sample_sigma2_global: + current_sigma2 = sigma2_init + global_model_config.update_global_error_variance(current_sigma2) + # Reset random effects terms + if self.has_rfx: + rfx_model.root_reset(alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rfx_tracker.root_reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container) + # Sample MCMC and burnin for each chain + for i in range(self.num_gfr, num_temp_samples): + is_mcmc = i + 1 > num_gfr + num_burnin + if is_mcmc: + mcmc_counter = i - num_gfr - num_burnin + 1 + if mcmc_counter % keep_every == 0: + keep_sample = True + else: + keep_sample = False + else: + if keep_burnin: + keep_sample = True + else: + keep_sample = False + if keep_sample: + sample_counter += 1 + + if self.include_mean_forest: + if link_is_probit: + # Sample latent probit variable z | - + outcome_pred = active_forest_mean.predict( + forest_dataset_train + ) + if self.has_rfx: + rfx_pred = rfx_model.predict( + rfx_dataset_train, rfx_tracker + ) + outcome_pred = outcome_pred + rfx_pred + # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:, 0] == 0, 0] 
= mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) + + # Update outcome: center z by y_bar before passing to forest + new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred + residual_train.update_data(new_outcome) + + # Sample the mean forest + forest_sampler_mean.sample_one_iteration( + self.forest_container_mean, + active_forest_mean, + forest_dataset_train, + residual_train, + cpp_rng, + global_model_config, + forest_model_config_mean, + keep_sample, + False, + num_threads, + ) + + if keep_sample: + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) + + # Sample the variance forest + if self.include_variance_forest: + forest_sampler_variance.sample_one_iteration( + self.forest_container_variance, + active_forest_variance, + forest_dataset_train, + residual_train, + cpp_rng, + global_model_config, + forest_model_config_variance, + keep_sample, + False, + num_threads, + ) + + if keep_sample: + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) + + # Sample variance parameters (if requested) + if self.sample_sigma2_global: + current_sigma2 = global_var_model.sample_one_iteration( + residual_train, cpp_rng, a_global, b_global + ) + global_model_config.update_global_error_variance(current_sigma2) + if keep_sample: + self.global_var_samples[sample_counter] = current_sigma2 + if self.sample_sigma2_leaf: + current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( + active_forest_mean, cpp_rng, a_leaf, b_leaf + ) + forest_model_config_mean.update_leaf_model_scale( + current_leaf_scale + ) + if keep_sample: + self.leaf_scale_samples[sample_counter] = ( + current_leaf_scale[0, 0] + ) + + # Sample random effects + if self.has_rfx: + rfx_model.sample( + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, + ) + + # Cloglog Gibbs updates + if link_is_cloglog: + # Update 
auxiliary data slot 1 with current forest predictions + forest_pred_current = forest_sampler_mean.get_cached_forest_predictions() + for j in range(train_size): + forest_dataset_train.set_auxiliary_data_value(1, j, forest_pred_current[j]) + + # Sample latent z_i's using truncated exponential + ordinal_sampler.update_latent_variables( + forest_dataset_train, residual_train, cpp_rng + ) + + # Sample gamma parameters (cutpoints) + ordinal_sampler.update_gamma_params( + forest_dataset_train, + residual_train, + cloglog_forest_shape, + cloglog_forest_rate, + cloglog_cutpoint_0, + cpp_rng, + ) + + # Update cumulative sum of exp(gamma) values + ordinal_sampler.update_cumulative_exp_sums(forest_dataset_train) + + # Retain cutpoint draw + if keep_sample: + cloglog_cutpoint_samples[:, sample_counter] = ( + forest_dataset_train.get_auxiliary_data_vector(2) + ) + + # Mark the model as sampled + self.sampled = True + + # Remove GFR samples if they are not to be retained + if not keep_gfr and num_gfr > 0: + for i in range(num_gfr): + if self.include_mean_forest: + self.forest_container_mean.delete_sample(0) + if self.include_variance_forest: + self.forest_container_variance.delete_sample(0) + if self.has_rfx: + self.rfx_container.delete_sample(0) + if self.sample_sigma2_global: + self.global_var_samples = self.global_var_samples[num_gfr:] + if self.sample_sigma2_leaf: + self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] + if self.include_mean_forest: + yhat_train_raw = yhat_train_raw[:, num_gfr:] + if self.include_variance_forest: + sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] + if link_is_cloglog: + cloglog_cutpoint_samples = cloglog_cutpoint_samples[:, num_gfr:] + self.num_samples -= num_gfr + + # Store cloglog results (cutpoints only for ordinal, num_categories always) + if link_is_cloglog: + self.cloglog_num_categories = cloglog_num_categories + if not outcome_is_binary: + self.cloglog_cutpoint_samples = cloglog_cutpoint_samples + + # Store predictions + if 
self.sample_sigma2_global: + self.global_var_samples = self.global_var_samples * self.y_std * self.y_std + + if self.sample_sigma2_leaf: + self.leaf_scale_samples = self.leaf_scale_samples + + if self.include_mean_forest: + self.y_hat_train = yhat_train_raw * self.y_std + self.y_bar + if self.has_test: + yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict( + forest_dataset_test.dataset_cpp + ) + self.y_hat_test = yhat_test_raw * self.y_std + self.y_bar + + # TODO: make rfx_preds_train and rfx_preds_test persistent properties + if self.has_rfx: + rfx_preds_train = ( + self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) + * self.y_std + ) + if has_rfx_test: + rfx_preds_test = ( + self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) + * self.y_std + ) + if self.include_mean_forest: + self.y_hat_train = self.y_hat_train + rfx_preds_train + if self.has_test: + self.y_hat_test = self.y_hat_test + rfx_preds_test + else: + self.y_hat_train = rfx_preds_train + if self.has_test: + self.y_hat_test = rfx_preds_test + + if self.include_variance_forest: + if self.sample_sigma2_global: + self.sigma2_x_train = np.empty_like(sigma2_x_train_raw) + for i in range(self.num_samples): + self.sigma2_x_train[:, i] = ( + np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i] + ) + else: + self.sigma2_x_train = ( + np.exp(sigma2_x_train_raw) + * self.sigma2_init + * self.y_std + * self.y_std + ) + if self.has_test: + sigma2_x_test_raw = ( + self.forest_container_variance.forest_container_cpp.Predict( + forest_dataset_test.dataset_cpp + ) + ) + if self.sample_sigma2_global: + self.sigma2_x_test = sigma2_x_test_raw + for i in range(self.num_samples): + self.sigma2_x_test[:, i] = ( + sigma2_x_test_raw[:, i] * self.global_var_samples[i] + ) + else: + self.sigma2_x_test = ( + sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std + ) + return self def predict( self, From 2f3a3c0c417450093d19689125aaf8989b2110a8 Mon Sep 17 00:00:00 2001 
From: Drew Herren Date: Sat, 11 Apr 2026 19:20:19 -0400 Subject: [PATCH 33/64] Fixed none-type initialization bugs in the python to C++ interface --- stochtree/bart.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index cd8be874..2563b3eb 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1107,7 +1107,7 @@ def sample( "random_seed": random_seed, "a_sigma2_global": a_global, "b_sigma2_global": b_global, - "sigma2_global_init": sigma2_init, + "sigma2_global_init": 1.0, # TODO: calibrate this before "sample_sigma2_global": sample_sigma2_global, "num_trees_mean": num_trees_mean, "alpha_mean": alpha_mean, @@ -1120,7 +1120,7 @@ def sample( "num_features_subsample_mean": num_features_subsample_mean, "a_sigma2_mean": a_leaf, "b_sigma2_mean": b_leaf, - "sigma2_mean_init": sigma2_init, + "sigma2_mean_init": -1.0, "sample_sigma2_leaf_mean": sample_sigma2_leaf, "num_trees_variance": num_trees_variance, "leaf_prior_calibration_param": a_0, @@ -1150,15 +1150,15 @@ def sample( p = X_train_processed.shape[1], basis_train = leaf_basis_train if self.has_basis else None, basis_test = leaf_basis_test if self.has_basis and self.has_test else None, - basis_dim = self.num_basis if self.has_basis else None, + basis_dim = self.num_basis if self.has_basis else 0, obs_weights_train = observation_weights if observation_weights is not None else None, obs_weights_test = None, rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test, rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, - rfx_num_groups = num_rfx_groups if self.has_rfx else None, - rfx_basis_dim = self.num_rfx_basis if self.has_rfx else None, + rfx_num_groups = num_rfx_groups if self.has_rfx else 0, + rfx_basis_dim = self.num_rfx_basis if self.has_rfx else 0, num_gfr = num_gfr, num_burnin = num_burnin, keep_every = keep_every, From 77b88d19e1cce7f844db8d4e1714a67dd0c1104b Mon Sep 17 00:00:00 2001 From: Drew Herren 
Date: Sat, 11 Apr 2026 19:27:50 -0400 Subject: [PATCH 34/64] Explicitly convert numpy arrays to column-major ("Fortran style") --- stochtree/bart.py | 77 ++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 2563b3eb..0c5498f6 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1098,6 +1098,7 @@ def sample( rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) if run_cpp: + # Arrange all config in a large python dictionary bart_config = { "standardize_outcome": self.standardize, "num_threads": num_threads, @@ -1141,15 +1142,31 @@ def sample( "var_weights_variance": variable_weights_variance } + # Remove None values from config (alternative is to check for Nones on the C++ side when unpacking into non-optional types) + bart_config = {k: v for k, v in bart_config.items() if v is not None} + + # Convert arrays to F-contiguous (column-major) before calling C++. + # convert_numpy_to_bart_data stores raw pointers into these arrays; if + # pybind11 has to make an F-contiguous copy (because the input is C-order) + # that copy is destroyed when the helper returns, leaving a dangling pointer. + # Passing already-F-contiguous arrays causes pybind11 to return a view of + # the original, which remains alive in this Python scope. 
+ X_train_cpp = np.asfortranarray(X_train_processed) + y_train_cpp = np.asfortranarray(y_train) + X_test_cpp = np.asfortranarray(X_test_processed) if self.has_test else None + basis_train_cpp = np.asfortranarray(leaf_basis_train) if self.has_basis else None + basis_test_cpp = np.asfortranarray(leaf_basis_test) if self.has_basis and self.has_test else None + + # Run the BART sampler from C++ bart_results = bart_sample_cpp( - X_train = X_train_processed, - y_train = y_train, - X_test = X_test_processed if self.has_test else None, - n_train = X_train_processed.shape[0], - n_test = X_test_processed.shape[0] if self.has_test else 0, - p = X_train_processed.shape[1], - basis_train = leaf_basis_train if self.has_basis else None, - basis_test = leaf_basis_test if self.has_basis and self.has_test else None, + X_train = X_train_cpp, + y_train = y_train_cpp, + X_test = X_test_cpp, + n_train = X_train_cpp.shape[0], + n_test = X_test_cpp.shape[0] if self.has_test else 0, + p = X_train_cpp.shape[1], + basis_train = basis_train_cpp, + basis_test = basis_test_cpp, basis_dim = self.num_basis if self.has_basis else 0, obs_weights_train = observation_weights if observation_weights is not None else None, obs_weights_test = None, @@ -1166,31 +1183,37 @@ def sample( config_input = bart_config ) - self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) - self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] + # Unpack mean forest results if self.include_variance_forest: - self.forest_container_variance = ForestContainer(num_trees=num_trees_variance, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) - self.forest_container_variance.forest_container_cpp = bart_results["forest_container_variance"] - if sample_sigma2_global: - self.global_var_samples = bart_results["global_var_samples"] * self.y_std * self.y_std - if sample_sigma2_leaf: - 
self.leaf_scale_samples = bart_results["leaf_scale_samples"] - mean_forest_preds_train = bart_results["mean_forest_predictions_train"] - mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") - self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar - if self.has_test: + self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] + mean_forest_preds_train = bart_results["mean_forest_predictions_train"] + mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar + if self.has_test: mean_forest_preds_test = bart_results["mean_forest_predictions_test"] mean_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar + + # Unpack variance forest results if self.include_variance_forest: - variance_forest_preds_train = bart_results["variance_forest_predictions_train"] - variance_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") - self.variance_forest_preds_train = variance_forest_preds_train * self.y_std * self.y_std - if self.has_test: - variance_forest_preds_test = bart_results["variance_forest_predictions_test"] - variance_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") - self.variance_forest_preds_test = variance_forest_preds_test * self.y_std * self.y_std + self.forest_container_variance = ForestContainer(num_trees=num_trees_variance, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + self.forest_container_variance.forest_container_cpp = bart_results["forest_container_variance"] + variance_forest_preds_train = bart_results["variance_forest_predictions_train"] + 
variance_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + self.variance_forest_preds_train = variance_forest_preds_train * self.y_std * self.y_std + if self.has_test: + variance_forest_preds_test = bart_results["variance_forest_predictions_test"] + variance_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + self.variance_forest_preds_test = variance_forest_preds_test * self.y_std * self.y_std + + # Unpack parameter samples + if sample_sigma2_global: + self.global_var_samples = bart_results["global_var_samples"] * self.y_std * self.y_std + if sample_sigma2_leaf: + self.leaf_scale_samples = bart_results["leaf_scale_samples"] + # Unpack other model metadata self.num_samples = bart_results["num_samples"] self.sampled = True From 312b563480d5267ff0e698606cd2488f03d231e7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:00:42 -0400 Subject: [PATCH 35/64] Fix logic inversion bug --- stochtree/bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 0c5498f6..4d59f093 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1115,7 +1115,7 @@ def sample( "beta_mean": beta_mean, "min_samples_leaf_mean": min_samples_leaf_mean, "max_depth_mean": max_depth_mean, - "leaf_constant_mean": True if self.has_basis else False, + "leaf_constant_mean": False if self.has_basis else True, "leaf_dim_mean": self.num_basis if self.has_basis else 1, "exponentiated_leaf_mean": False, "num_features_subsample_mean": num_features_subsample_mean, From fca909a50d55fb64cd19cd003b165aee8d5ad7d5 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:00:57 -0400 Subject: [PATCH 36/64] Fix std::move bug with ForestContainerCpp --- src/py_stochtree.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 5c20e422..bcf5a08a 100644 --- a/src/py_stochtree.cpp +++ 
b/src/py_stochtree.cpp @@ -2311,14 +2311,14 @@ inline py::dict convert_bart_results_to_dict( // Transfer ownership of mean forest pointers if (results_raw.mean_forests != nullptr) { - output["mean_forests"] = py::cast(ForestContainerCpp(std::move(results_raw.mean_forests), config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean)); + output["mean_forests"] = py::cast(std::make_unique(std::move(results_raw.mean_forests), config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean)); } else { output["mean_forests"] = py::none(); } // Transfer ownership of variance forest pointers if (results_raw.variance_forests != nullptr) { - output["variance_forests"] = py::cast(ForestContainerCpp(std::move(results_raw.variance_forests), config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance)); + output["variance_forests"] = py::cast(std::make_unique(std::move(results_raw.variance_forests), config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance)); } else { output["variance_forests"] = py::none(); } From f8d67afc72cbf08969e9aebb88f6ef92945d2a5b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:03:25 -0400 Subject: [PATCH 37/64] Fix result unpacking bugs --- src/py_stochtree.cpp | 12 ++++++------ stochtree/bart.py | 19 +++++++++---------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index bcf5a08a..a08f1a11 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2311,16 +2311,16 @@ inline py::dict convert_bart_results_to_dict( // Transfer ownership of mean forest pointers if (results_raw.mean_forests != nullptr) { - output["mean_forests"] = py::cast(std::make_unique(std::move(results_raw.mean_forests), config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, 
config.exponentiated_leaf_mean)); + output["forest_container_mean"] = py::cast(std::make_unique(std::move(results_raw.mean_forests), config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean)); } else { - output["mean_forests"] = py::none(); + output["forest_container_mean"] = py::none(); } // Transfer ownership of variance forest pointers if (results_raw.variance_forests != nullptr) { - output["variance_forests"] = py::cast(std::make_unique(std::move(results_raw.variance_forests), config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance)); + output["forest_container_variance"] = py::cast(std::make_unique(std::move(results_raw.variance_forests), config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance)); } else { - output["variance_forests"] = py::none(); + output["forest_container_variance"] = py::none(); } // Move parameter vector samples @@ -2367,12 +2367,12 @@ inline py::dict convert_bart_results_to_dict( // Global error variance samples if (results_raw.global_error_variance_samples.empty()) { - output["global_error_variance_samples"] = py::none(); + output["global_var_samples"] = py::none(); } else { auto input_vec = results_raw.global_error_variance_samples; py::array_t array(input_vec.size()); std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); - output["global_error_variance_samples"] = array; + output["global_var_samples"] = array; } // Leaf scale samples diff --git a/stochtree/bart.py b/stochtree/bart.py index 4d59f093..a441a2d4 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1184,16 +1184,15 @@ def sample( ) # Unpack mean forest results - if self.include_variance_forest: - self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) - self.forest_container_mean.forest_container_cpp = 
bart_results["forest_container_mean"] - mean_forest_preds_train = bart_results["mean_forest_predictions_train"] - mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") - self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar - if self.has_test: - mean_forest_preds_test = bart_results["mean_forest_predictions_test"] - mean_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") - self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar + self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] + mean_forest_preds_train = bart_results["mean_forest_predictions_train"] + mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar + if self.has_test: + mean_forest_preds_test = bart_results["mean_forest_predictions_test"] + mean_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar # Unpack variance forest results if self.include_variance_forest: From b9a6ea6bbdd600d8ac4af93f5c0a3ea2c3321f73 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:03:59 -0400 Subject: [PATCH 38/64] Add GHA debugging step to show `00install.out` from R check workflows --- .github/workflows/r-test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index f6e3a38c..39fbd242 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -49,3 +49,7 @@ jobs: - uses: r-lib/actions/check-r-package@v2 with: working-directory: 'stochtree_cran' + + - name: Show install log + run: cat stochtree_cran/check/stochtree.Rcheck/00install.out + shell: bash From 
bbbdf901bbf4ba0f5b50d167d87593f94c989919 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:05:54 -0400 Subject: [PATCH 39/64] Fix unpacking bugs --- stochtree/bart.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index a441a2d4..6e76385f 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1183,20 +1183,26 @@ def sample( config_input = bart_config ) + # Unpack standardization params computed by C++ sampler + self.y_bar = bart_results["y_bar"] + self.y_std = bart_results["y_std"] + # Unpack mean forest results - self.forest_container_mean = ForestContainer(num_trees=num_trees_mean, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + self.forest_container_mean = ( + ForestContainer(num_trees_mean, 1, True, False) + if not self.has_basis + else ForestContainer(num_trees_mean, self.num_basis, False, False) + ) self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] - mean_forest_preds_train = bart_results["mean_forest_predictions_train"] - mean_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + mean_forest_preds_train = bart_results["mean_forest_predictions_train"].reshape(self.n_train, bart_results["num_samples"], order="F") self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar if self.has_test: - mean_forest_preds_test = bart_results["mean_forest_predictions_test"] - mean_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + mean_forest_preds_test = bart_results["mean_forest_predictions_test"].reshape(self.n_test, bart_results["num_samples"], order="F") self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar # Unpack variance forest results if self.include_variance_forest: - self.forest_container_variance = ForestContainer(num_trees=num_trees_variance, num_samples=num_mcmc, num_burnin=num_burnin, keep_every=keep_every) + 
self.forest_container_variance = ForestContainer(num_trees_variance, 1, True, True) self.forest_container_variance.forest_container_cpp = bart_results["forest_container_variance"] variance_forest_preds_train = bart_results["variance_forest_predictions_train"] variance_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") From c8f7b4a54ac3ed9e28585060db104f4f26310e66 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:09:05 -0400 Subject: [PATCH 40/64] Fix reshape bug --- stochtree/bart.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 6e76385f..c45e63d4 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1204,12 +1204,10 @@ def sample( if self.include_variance_forest: self.forest_container_variance = ForestContainer(num_trees_variance, 1, True, True) self.forest_container_variance.forest_container_cpp = bart_results["forest_container_variance"] - variance_forest_preds_train = bart_results["variance_forest_predictions_train"] - variance_forest_preds_train.reshape(self.n_train, bart_results["num_samples"], order="F") + variance_forest_preds_train = bart_results["variance_forest_predictions_train"].reshape(self.n_train, bart_results["num_samples"], order="F") self.variance_forest_preds_train = variance_forest_preds_train * self.y_std * self.y_std if self.has_test: - variance_forest_preds_test = bart_results["variance_forest_predictions_test"] - variance_forest_preds_test.reshape(self.n_test, bart_results["num_samples"], order="F") + variance_forest_preds_test = bart_results["variance_forest_predictions_test"].reshape(self.n_test, bart_results["num_samples"], order="F") self.variance_forest_preds_test = variance_forest_preds_test * self.y_std * self.y_std # Unpack parameter samples From 2b47fa45cb6a329071884ffb65b1bfbaf8e7f9ab Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:17:33 -0400 Subject: [PATCH 41/64] Always show install 
log --- .github/workflows/r-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index 39fbd242..14f810a7 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -51,5 +51,6 @@ jobs: working-directory: 'stochtree_cran' - name: Show install log + if: always() run: cat stochtree_cran/check/stochtree.Rcheck/00install.out shell: bash From e090dde4185390660f8452c57f3cfdafc3895545 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 20:35:03 -0400 Subject: [PATCH 42/64] Avoid truncation in install log printing --- .github/workflows/r-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index 14f810a7..71b0abe6 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -52,5 +52,5 @@ jobs: - name: Show install log if: always() - run: cat stochtree_cran/check/stochtree.Rcheck/00install.out + run: tail -1000 stochtree_cran/check/stochtree.Rcheck/00install.out shell: bash From 6e4310523fa57d7b4e6141bd76c26c0c88e6728d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 23:42:48 -0400 Subject: [PATCH 43/64] Update windows makevars template --- src/Makevars.win.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Makevars.win.in b/src/Makevars.win.in index af8d6f83..d9fa4526 100644 --- a/src/Makevars.win.in +++ b/src/Makevars.win.in @@ -24,7 +24,9 @@ PKG_LIBS = \ OBJECTS = \ forest.o \ kernel.o \ + R_bart.o \ R_data.o \ + bart_sampler.o \ R_random_effects.o \ R_utils.o \ sampler.o \ From 74d7df502e7187b51f7a20c4f6cee138b198db99 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 11 Apr 2026 23:46:28 -0400 Subject: [PATCH 44/64] Remove install log debug statement in R GHA workflow and add two debugging scripts for R and Python --- .github/workflows/r-test.yml | 5 --- tools/debug/debug_cpp_sampler.R | 67 ++++++++++++++++++++++++++++++++ 
tools/debug/debug_cpp_sampler.py | 53 +++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 5 deletions(-) create mode 100644 tools/debug/debug_cpp_sampler.R create mode 100644 tools/debug/debug_cpp_sampler.py diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index 71b0abe6..f6e3a38c 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -49,8 +49,3 @@ jobs: - uses: r-lib/actions/check-r-package@v2 with: working-directory: 'stochtree_cran' - - - name: Show install log - if: always() - run: tail -1000 stochtree_cran/check/stochtree.Rcheck/00install.out - shell: bash diff --git a/tools/debug/debug_cpp_sampler.R b/tools/debug/debug_cpp_sampler.R new file mode 100644 index 00000000..69d0559f --- /dev/null +++ b/tools/debug/debug_cpp_sampler.R @@ -0,0 +1,67 @@ +################################################################################ +# Minimal script for debugging the C++ sampler under lldb. +# +# Usage (from the repo root): +# lldb -- R --vanilla -f tools/debug/debug_cpp_sampler.R +# # then at the (lldb) prompt: +# # run +# # bt (after the crash, to get a backtrace) +# # frame info (to see the crashing frame) +# +# Alternatively, attach to an already-running R process: +# lldb -p $(pgrep -n R) +################################################################################ + +suppressPackageStartupMessages(devtools::load_all(".")) + +# --- Data generation (mirrors debug_cpp_sampler.py) -------------------------- +seed <- 1001 +n <- 10000 +p <- 10 +set.seed(1234) + +X <- matrix(runif(n * p), nrow = n, ncol = p) +f_X <- ifelse( + X[, 1] < 0.25, + -7.5, + ifelse(X[, 1] < 0.50, -2.5, ifelse(X[, 1] < 0.75, 2.5, 7.5)) +) +y <- f_X + rnorm(n, sd = 1.0) + +n_test <- round(0.2 * n) +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] + +cat(sprintf( + 
"n_train=%d n_test=%d p=%d seed=%d\n", + length(train_inds), + n_test, + p, + seed +)) +cat("Calling bart() with run_cpp=TRUE ...\n") + +# --- Run C++ sampler ---------------------------------------------------------- +m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 100, + general_params = list(random_seed = seed), + mean_forest_params = list(num_trees = 200), + run_cpp = TRUE +) + +cat("Completed successfully.\n") +cat(sprintf( + "dim(y_hat_test): %d x %d\n", + nrow(m$y_hat_test), + ncol(m$y_hat_test) +)) diff --git a/tools/debug/debug_cpp_sampler.py b/tools/debug/debug_cpp_sampler.py new file mode 100644 index 00000000..67d22354 --- /dev/null +++ b/tools/debug/debug_cpp_sampler.py @@ -0,0 +1,53 @@ +"""Minimal script for debugging the C++ sampler under lldb. + +Usage: + source venv/bin/activate + lldb -- python debug/debug_cpp_sampler.py + # then at the (lldb) prompt: + # run + # bt (after the crash, to get a backtrace) + # frame info (to see the crashing frame) +""" + +import numpy as np +from stochtree import BARTModel + +seed = 1001 +n = 10000 +p = 10 +rng = np.random.default_rng(1234) + +X = rng.uniform(size=(n, p)) +f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -7.5, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -2.5, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 2.5, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 7.5, 0) +) +y = f_X + rng.normal(scale=1.0, size=n) + +n_test = round(0.2 * n) +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] + +print(f"n_train={len(train_inds)} n_test={n_test} p={p} seed={seed}") +print("Calling BARTModel.sample() with run_cpp=True ...") + +m = BARTModel() +m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=10, + num_burnin=0, + num_mcmc=100, + 
general_params={"random_seed": seed}, + mean_forest_params={"num_trees": 200}, + run_cpp=True, +) + +print("Completed successfully.") +print(f"y_hat_test shape: {m.y_hat_test.shape}") From c7f36953a531f7b11031626761520a29e07d75df Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 12 Apr 2026 19:44:53 -0400 Subject: [PATCH 45/64] Refactored BARTSampler to store config and data references internally --- include/stochtree/bart_sampler.h | 14 ++- include/stochtree/linear_regression.h | 8 +- include/stochtree/probit.h | 2 +- src/R_bart.cpp | 4 +- src/bart_sampler.cpp | 164 +++++++++++++------------- src/py_stochtree.cpp | 4 +- 6 files changed, 101 insertions(+), 95 deletions(-) diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 43112580..3e950c9a 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -23,18 +25,22 @@ class BARTSampler { BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data); // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions - void run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_gfr, bool keep_gfr); + void run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr); // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions - void run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_burnin, int keep_every, int num_mcmc); + void run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc); private: /*! Initialize state variables */ - void InitializeState(BARTSamples& samples, BARTConfig& config, BARTData& data); + void InitializeState(BARTSamples& samples); bool initialized_ = false; /*! 
Internal sample runner function */ - void RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample); + void RunOneIteration(BARTSamples& samples, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample); + + /*! Internal reference to config and data state */ + BARTConfig& config_; + BARTData& data_; /*! Mean forest state */ std::unique_ptr mean_forest_; diff --git a/include/stochtree/linear_regression.h b/include/stochtree/linear_regression.h index c9343f73..28893d5b 100644 --- a/include/stochtree/linear_regression.h +++ b/include/stochtree/linear_regression.h @@ -25,7 +25,7 @@ namespace StochTree { * \param gen Random number generator * \return double */ -inline double sample_univariate_gaussian_regression_coefficient(double* y, double* x, double error_variance, double prior_variance, int n, std::mt19937& gen) { +static double sample_univariate_gaussian_regression_coefficient(double* y, double* x, double error_variance, double prior_variance, int n, std::mt19937& gen) { double sum_xx = 0.0; double sum_yx = 0.0; for (int i = 0; i < n; i++) { @@ -51,7 +51,7 @@ inline double sample_univariate_gaussian_regression_coefficient(double* y, doubl * \param n Number of observations * \param gen Random number generator */ -inline void sample_general_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_12, double prior_variance_22, int n, std::mt19937& gen) { +static void sample_general_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_12, double prior_variance_22, int n, std::mt19937& gen) { double det_prior_var = prior_variance_11 * prior_variance_22 - 
prior_variance_12 * prior_variance_12; double inv_prior_var_11 = prior_variance_22 / det_prior_var; double inv_prior_var_12 = -prior_variance_12 / det_prior_var; @@ -99,7 +99,7 @@ inline void sample_general_bivariate_gaussian_regression_coefficients(double* ou * \param n Number of observations * \param gen Random number generator */ -inline void sample_diagonal_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_22, int n, std::mt19937& gen) { +static void sample_diagonal_bivariate_gaussian_regression_coefficients(double* output, double* y, double* x1, double* x2, double error_variance, double prior_variance_11, double prior_variance_22, int n, std::mt19937& gen) { double inv_prior_var_11 = 1.0 / prior_variance_11; double inv_prior_var_22 = 1.0 / prior_variance_22; double sum_x1x1 = 0.0; @@ -142,7 +142,7 @@ inline void sample_diagonal_bivariate_gaussian_regression_coefficients(double* o * \param n Number of observations * \param gen Random number generator */ -Eigen::VectorXd sample_general_gaussian_regression_coefficients(Eigen::VectorXd& y, Eigen::MatrixXd& X, double error_variance, Eigen::MatrixXd& prior_variance, int n, std::mt19937& gen) { +static Eigen::VectorXd sample_general_gaussian_regression_coefficients(Eigen::VectorXd& y, Eigen::MatrixXd& X, double error_variance, Eigen::MatrixXd& prior_variance, int n, std::mt19937& gen) { int p = X.cols(); Eigen::MatrixXd inv_prior_var = prior_variance.inverse(); Eigen::MatrixXd XtX = X.transpose() * X; diff --git a/include/stochtree/probit.h b/include/stochtree/probit.h index 300c25ee..24a49252 100644 --- a/include/stochtree/probit.h +++ b/include/stochtree/probit.h @@ -9,7 +9,7 @@ namespace StochTree { -void sample_probit_latent_outcome(std::mt19937& gen, double* outcome, double* conditional_mean, double* partial_residual, double y_bar, int n) { +static void sample_probit_latent_outcome(std::mt19937& 
gen, double* outcome, double* conditional_mean, double* partial_residual, double y_bar, int n) { double uniform_draw_std; double uniform_draw_trunc; double quantile; diff --git a/src/R_bart.cpp b/src/R_bart.cpp index ff82059f..a521f9a5 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -256,8 +256,8 @@ cpp11::writable::list bart_sample_cpp( StochTree::BARTSampler bart_sampler(results_raw, config, data); // Run the sampler - bart_sampler.run_gfr(results_raw, config, data, num_gfr, true); - bart_sampler.run_mcmc(results_raw, config, data, num_burnin, keep_every, num_mcmc); + bart_sampler.run_gfr(results_raw, num_gfr, true); + bart_sampler.run_mcmc(results_raw, num_burnin, keep_every, num_mcmc); // Unprotect protected R objects UNPROTECT(protect_count); diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 8aa03b60..7fe5292e 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -12,51 +12,51 @@ namespace StochTree { -BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data) { - InitializeState(samples, config, data); +BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data) : config_{config}, data_{data} { + InitializeState(samples); } -void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BARTData& data) { +void BARTSampler::InitializeState(BARTSamples& samples) { // Load data from BARTData object into ForestDataset object forest_dataset_ = std::make_unique(); - forest_dataset_->AddCovariates(data.X_train, data.n_train, data.p, /*row_major=*/false); - if (data.basis_train != nullptr) { - forest_dataset_->AddBasis(data.basis_train, data.n_train, data.basis_dim, /*row_major=*/false); - } - if (data.obs_weights_train != nullptr) { - forest_dataset_->AddVarianceWeights(data.obs_weights_train, data.n_train); - } - samples.num_train = data.n_train; - samples.num_test = data.n_test; - residual_ = std::make_unique(data.y_train, data.n_train); - outcome_raw_ = 
std::make_unique(data.y_train, data.n_train); - if (data.X_test != nullptr) { + forest_dataset_->AddCovariates(data_.X_train, data_.n_train, data_.p, /*row_major=*/false); + if (data_.basis_train != nullptr) { + forest_dataset_->AddBasis(data_.basis_train, data_.n_train, data_.basis_dim, /*row_major=*/false); + } + if (data_.obs_weights_train != nullptr) { + forest_dataset_->AddVarianceWeights(data_.obs_weights_train, data_.n_train); + } + samples.num_train = data_.n_train; + samples.num_test = data_.n_test; + residual_ = std::make_unique(data_.y_train, data_.n_train); + outcome_raw_ = std::make_unique(data_.y_train, data_.n_train); + if (data_.X_test != nullptr) { forest_dataset_test_ = std::make_unique(); - forest_dataset_test_->AddCovariates(data.X_test, data.n_test, data.p, /*row_major=*/false); - if (data.basis_test != nullptr) { - forest_dataset_test_->AddBasis(data.basis_test, data.n_test, data.basis_dim, /*row_major=*/false); + forest_dataset_test_->AddCovariates(data_.X_test, data_.n_test, data_.p, /*row_major=*/false); + if (data_.basis_test != nullptr) { + forest_dataset_test_->AddBasis(data_.basis_test, data_.n_test, data_.basis_dim, /*row_major=*/false); } - if (data.obs_weights_test != nullptr) { - forest_dataset_test_->AddVarianceWeights(data.obs_weights_test, data.n_test); + if (data_.obs_weights_test != nullptr) { + forest_dataset_test_->AddVarianceWeights(data_.obs_weights_test, data_.n_test); } has_test_ = true; } // Precompute outcome mean and variance for standardization and calibration double y_mean = 0.0, M2 = 0.0, y_mean_prev = 0.0; - for (int i = 0; i < data.n_train; i++) { + for (int i = 0; i < data_.n_train; i++) { y_mean_prev = y_mean; - y_mean = y_mean_prev + (data.y_train[i] - y_mean_prev) / (i + 1); - M2 = M2 + (data.y_train[i] - y_mean_prev) * (data.y_train[i] - y_mean); + y_mean = y_mean_prev + (data_.y_train[i] - y_mean_prev) / (i + 1); + M2 = M2 + (data_.y_train[i] - y_mean_prev) * (data_.y_train[i] - y_mean); } - double y_var = 
M2 / data.n_train; + double y_var = M2 / data_.n_train; // Compute outcome location and scale for standardization - if (config.link_function == LinkFunction::Probit) { + if (config_.link_function == LinkFunction::Probit) { samples.y_std = 1.0; samples.y_bar = norm_inv_cdf(y_mean); } else { - if (config.standardize_outcome) { + if (config_.standardize_outcome) { samples.y_bar = y_mean; samples.y_std = std::sqrt(y_var); } else { @@ -66,120 +66,120 @@ void BARTSampler::InitializeState(BARTSamples& samples, BARTConfig& config, BART } // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized outcomes - for (int i = 0; i < data.n_train; i++) residual_->GetData()[i] = (data.y_train[i] - samples.y_bar) / samples.y_std; + for (int i = 0; i < data_.n_train; i++) residual_->GetData()[i] = (data_.y_train[i] - samples.y_bar) / samples.y_std; // Initialize mean forest state (if present) - if (config.num_trees_mean > 0) { - mean_forest_ = std::make_unique(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); - samples.mean_forests = std::make_unique(config.num_trees_mean, config.leaf_dim_mean, config.leaf_constant_mean, config.exponentiated_leaf_mean); - mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config.feature_types, config.num_trees_mean, data.n_train); - tree_prior_mean_ = std::make_unique(config.alpha_mean, config.beta_mean, config.min_samples_leaf_mean, config.max_depth_mean); + if (config_.num_trees_mean > 0) { + mean_forest_ = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); + mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_mean, data_.n_train); + 
tree_prior_mean_ = std::make_unique(config_.alpha_mean, config_.beta_mean, config_.min_samples_leaf_mean, config_.max_depth_mean); mean_forest_->SetLeafValue(0.0); - UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config.leaf_constant_mean, std::minus()); + UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config_.leaf_constant_mean, std::minus()); mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); has_mean_forest_ = true; - if (config.sigma2_mean_init < 0.0) { - if (config.link_function == LinkFunction::Probit) { - config.sigma2_mean_init = 1.0 / config.num_trees_mean; + if (config_.sigma2_mean_init < 0.0) { + if (config_.link_function == LinkFunction::Probit) { + config_.sigma2_mean_init = 1.0 / config_.num_trees_mean; } else { - config.sigma2_mean_init = y_var / config.num_trees_mean; + config_.sigma2_mean_init = y_var / config_.num_trees_mean; } } - if (config.sample_sigma2_leaf_mean) { - if (config.b_sigma2_mean <= 0.0) { - if (config.link_function == LinkFunction::Probit) { - config.b_sigma2_mean = 1.0 / (2 * config.num_trees_mean); + if (config_.sample_sigma2_leaf_mean) { + if (config_.b_sigma2_mean <= 0.0) { + if (config_.link_function == LinkFunction::Probit) { + config_.b_sigma2_mean = 1.0 / (2 * config_.num_trees_mean); } else { - config.b_sigma2_mean = y_var / (2 * config.num_trees_mean); + config_.b_sigma2_mean = y_var / (2 * config_.num_trees_mean); } } } } // Initialize variance forest state (if present) - if (config.num_trees_variance > 0) { - variance_forest_ = std::make_unique(config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance); - samples.variance_forests = std::make_unique(config.num_trees_variance, config.leaf_dim_variance, config.leaf_constant_variance, config.exponentiated_leaf_variance); - variance_forest_tracker_ = 
std::make_unique(forest_dataset_->GetCovariates(), config.feature_types, config.num_trees_variance, data.n_train); - tree_prior_variance_ = std::make_unique(config.alpha_variance, config.beta_variance, config.min_samples_leaf_variance, config.max_depth_variance); - variance_forest_->SetLeafValue(1.0 / config.num_trees_variance); + if (config_.num_trees_variance > 0) { + variance_forest_ = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); + samples.variance_forests = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); + variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_variance, data_.n_train); + tree_prior_variance_ = std::make_unique(config_.alpha_variance, config_.beta_variance, config_.min_samples_leaf_variance, config_.max_depth_variance); + variance_forest_->SetLeafValue(1.0 / config_.num_trees_variance); variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); has_variance_forest_ = true; - if (config.shape_variance_forest <= 0.0 || config.scale_variance_forest <= 0.0) { - if (config.leaf_prior_calibration_param <= 0.0) { - config.leaf_prior_calibration_param = 1.5; + if (config_.shape_variance_forest <= 0.0 || config_.scale_variance_forest <= 0.0) { + if (config_.leaf_prior_calibration_param <= 0.0) { + config_.leaf_prior_calibration_param = 1.5; } - if (config.shape_variance_forest <= 0.0) { - config.shape_variance_forest = config.num_trees_variance / (config.leaf_prior_calibration_param * config.leaf_prior_calibration_param) + 0.5; + if (config_.shape_variance_forest <= 0.0) { + config_.shape_variance_forest = config_.num_trees_variance / (config_.leaf_prior_calibration_param * config_.leaf_prior_calibration_param) + 0.5; } - if (config.scale_variance_forest <= 0.0) { - 
config.scale_variance_forest = config.num_trees_variance / (config.leaf_prior_calibration_param * config.leaf_prior_calibration_param); + if (config_.scale_variance_forest <= 0.0) { + config_.scale_variance_forest = config_.num_trees_variance / (config_.leaf_prior_calibration_param * config_.leaf_prior_calibration_param); } } } // Global error variance model - if (config.sample_sigma2_global) { + if (config_.sample_sigma2_global) { var_model_ = std::make_unique(); sample_sigma2_global_ = true; } // Leaf scale model - if (config.sample_sigma2_leaf_mean) { + if (config_.sample_sigma2_leaf_mean) { leaf_scale_model_ = std::make_unique(); sample_sigma2_leaf_ = true; } // RNG - rng_ = std::mt19937(config.random_seed >= 0 ? config.random_seed : std::random_device{}()); + rng_ = std::mt19937(config_.random_seed >= 0 ? config_.random_seed : std::random_device{}()); // Other internal model state - global_variance_ = config.sigma2_global_init; - leaf_scale_ = config.sigma2_mean_init; - // leaf_scale_multivariate_ = config.sigma2_leaf_multivariate_init; + global_variance_ = config_.sigma2_global_init; + leaf_scale_ = config_.sigma2_mean_init; + // leaf_scale_multivariate_ = config_.sigma2_leaf_multivariate_init; initialized_ = true; } -void BARTSampler::run_gfr(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_gfr, bool keep_gfr) { +void BARTSampler::run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr) { // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); - std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); + std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); for (int i = 0; i < num_gfr; i++) { - RunOneIteration(samples, config, data, 
mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); + RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); } } -void BARTSampler::run_mcmc(BARTSamples& samples, BARTConfig& config, BARTData& data, int num_burnin, int keep_every, int num_mcmc) { +void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc) { std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); - std::unique_ptr variance_leaf_model_ptr = std::make_unique(config.shape_variance_forest, config.scale_variance_forest); + std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); bool keep_forest = false; for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; else keep_forest = false; - RunOneIteration(samples, config, data, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/false, /*keep_sample=*/keep_forest); + RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/false, /*keep_sample=*/keep_forest); } } -void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BARTData& data, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample) { +void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample) { if (has_mean_forest_) { if (gfr) { GFRSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config.var_weights_mean, config.sweep_update_indices_mean, global_variance_, config.feature_types, - config.cutpoint_grid_size, /*keep_forest=*/keep_sample, + 
config_.var_weights_mean, config_.sweep_update_indices_mean, global_variance_, config_.feature_types, + config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, - /*num_features_subsample=*/data.p, config.num_threads); + /*num_features_subsample=*/data_.p, config_.num_threads); } else { MCMCSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config.var_weights_mean, config.sweep_update_indices_mean, global_variance_, /*keep_forest=*/keep_sample, + config_.var_weights_mean, config_.sweep_update_indices_mean, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, - /*num_threads=*/config.num_threads); + /*num_threads=*/config_.num_threads); } } @@ -188,33 +188,33 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, BARTConfig& config, BART GFRSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, - config.var_weights_variance, config.sweep_update_indices_variance, global_variance_, config.feature_types, - config.cutpoint_grid_size, /*keep_forest=*/keep_sample, + config_.var_weights_variance, config_.sweep_update_indices_variance, global_variance_, config_.feature_types, + config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, - /*num_features_subsample=*/data.p, config.num_threads); + /*num_features_subsample=*/data_.p, config_.num_threads); } else { MCMCSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, - config.var_weights_variance, config.sweep_update_indices_variance, global_variance_, /*keep_forest=*/keep_sample, + config_.var_weights_variance, config_.sweep_update_indices_variance, 
global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, - /*num_threads=*/config.num_threads); + /*num_threads=*/config_.num_threads); } } - if (config.link_function == LinkFunction::Probit) { + if (config_.link_function == LinkFunction::Probit) { sample_probit_latent_outcome(rng_, outcome_raw_->GetData().data(), mean_forest_tracker_->GetSumPredictions(), - residual_->GetData().data(), samples.y_bar, data.n_train); + residual_->GetData().data(), samples.y_bar, data_.n_train); } if (sample_sigma2_global_) { global_variance_ = var_model_->SampleVarianceParameter( - residual_->GetData(), config.a_sigma2_global, config.b_sigma2_global, rng_); + residual_->GetData(), config_.a_sigma2_global, config_.b_sigma2_global, rng_); } if (sample_sigma2_leaf_) { leaf_scale_ = leaf_scale_model_->SampleVarianceParameter( - mean_forest_.get(), config.a_sigma2_mean, config.b_sigma2_mean, rng_); + mean_forest_.get(), config_.a_sigma2_mean, config_.b_sigma2_mean, rng_); mean_leaf_model->SetScale(leaf_scale_); } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index a08f1a11..f1657fc8 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2431,8 +2431,8 @@ py::dict bart_sample_cpp( StochTree::BARTSampler bart_sampler(bart_results_raw, bart_config, bart_data); // Run the sampler - bart_sampler.run_gfr(bart_results_raw, bart_config, bart_data, num_gfr, true); - bart_sampler.run_mcmc(bart_results_raw, bart_config, bart_data, num_burnin, keep_every, num_mcmc); + bart_sampler.run_gfr(bart_results_raw, num_gfr, true); + bart_sampler.run_mcmc(bart_results_raw, num_burnin, keep_every, num_mcmc); // Convert results to Python dictionary return convert_bart_results_to_dict(bart_results_raw, bart_config); From 0bc70b2fbf43c0a2bcb3fd6ca84dfb8fec53a1fa Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 12 Apr 2026 19:51:07 -0400 Subject: [PATCH 46/64] Pre-reserve forest predictions --- src/bart_sampler.cpp | 14 ++++++++++++++ 1 file changed, 
14 insertions(+) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 7fe5292e..5c24336d 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -148,12 +148,26 @@ void BARTSampler::run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr) { for (int i = 0; i < num_gfr; i++) { RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); } + if (keep_gfr) { + if (has_mean_forest_) { + samples.mean_forest_predictions_train.reserve(data_.n_train * num_gfr); + } + if (has_variance_forest_) { + samples.variance_forest_predictions_train.reserve(data_.n_train * num_gfr); + } + } } void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc) { std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); bool keep_forest = false; + if (has_mean_forest_) { + samples.mean_forest_predictions_train.reserve(data_.n_train * num_mcmc); + } + if (has_variance_forest_) { + samples.variance_forest_predictions_train.reserve(data_.n_train * num_mcmc); + } for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; From 8e5410fb58eedff725fae1f49cf8a30a3afbfd7d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 12 Apr 2026 20:20:21 -0400 Subject: [PATCH 47/64] Run test set predictions after the fact --- debug/bart_debug.cpp | 5 +++-- include/stochtree/bart_sampler.h | 7 +++++-- src/R_bart.cpp | 1 + src/bart_sampler.cpp | 21 ++++++++++++++++----- src/py_stochtree.cpp | 1 + 5 files changed, 26 insertions(+), 9 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index b8122cfa..e00dc458 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -131,8 +131,9 @@ static void run_scenario_0(int n, int n_test, int p, int num_trees, int num_gfr, 
StochTree::BARTSamples samples; StochTree::BARTSampler sampler(samples, config, data); - sampler.run_gfr(samples, config, data, num_gfr, true); - sampler.run_mcmc(samples, config, data, 0, 1, num_mcmc); + sampler.run_gfr(samples, num_gfr, true); + sampler.run_mcmc(samples, 0, 1, num_mcmc); + sampler.postprocess_samples(samples); report_bart(samples, test.y, "Scenario 0 (Homoskedastic BART)"); } diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 3e950c9a..14fb1bb9 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -24,12 +24,15 @@ class BARTSampler { public: BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data); - // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions + // Main entry point for running the BART GFR sampler void run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr); - // Main entry point for running the BART sampler, which dispatches to GFR warmup and MCMC sampling functions + // Main entry point for running the BART MCMC sampler void run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc); + // Post-process samples by extracting test set predictions and running any necessary transformations + void postprocess_samples(BARTSamples& samples); + private: /*! 
Initialize state variables */ void InitializeState(BARTSamples& samples); diff --git a/src/R_bart.cpp b/src/R_bart.cpp index a521f9a5..51e5ebec 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -258,6 +258,7 @@ cpp11::writable::list bart_sample_cpp( // Run the sampler bart_sampler.run_gfr(results_raw, num_gfr, true); bart_sampler.run_mcmc(results_raw, num_burnin, keep_every, num_mcmc); + bart_sampler.postprocess_samples(results_raw); // Unprotect protected R objects UNPROTECT(protect_count); diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 5c24336d..7fb8a7af 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -164,9 +164,15 @@ void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, bool keep_forest = false; if (has_mean_forest_) { samples.mean_forest_predictions_train.reserve(data_.n_train * num_mcmc); + if (has_test_) { + samples.mean_forest_predictions_test.reserve(data_.n_test * num_mcmc); + } } if (has_variance_forest_) { samples.variance_forest_predictions_train.reserve(data_.n_train * num_mcmc); + if (has_test_) { + samples.variance_forest_predictions_train.reserve(data_.n_test * num_mcmc); + } } for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) @@ -177,6 +183,16 @@ void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, } } +void BARTSampler::postprocess_samples(BARTSamples& samples) { + if (has_mean_forest_) { + if (has_test_) { + std::vector predictions = samples.mean_forests->Predict(*forest_dataset_test_); + samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), + predictions.data(), predictions.data() + predictions.size()); + } + } +} + void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample) { if (has_mean_forest_) { if (gfr) { @@ -240,11 +256,6 @@ void 
BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode double* mean_forest_preds_train = mean_forest_tracker_->GetSumPredictions(); samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), mean_forest_preds_train, mean_forest_preds_train + samples.num_train); - if (has_test_) { - std::vector predictions = samples.mean_forests->GetEnsemble(samples.num_samples - 1)->Predict(*forest_dataset_test_); - samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), - predictions.data(), predictions.data() + samples.num_test); - } } } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index f1657fc8..a00799e8 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2433,6 +2433,7 @@ py::dict bart_sample_cpp( // Run the sampler bart_sampler.run_gfr(bart_results_raw, num_gfr, true); bart_sampler.run_mcmc(bart_results_raw, num_burnin, keep_every, num_mcmc); + bart_sampler.postprocess_samples(bart_results_raw); // Convert results to Python dictionary return convert_bart_results_to_dict(bart_results_raw, bart_config); From 9667ec3d22313957ef0992030d2c70c33a08feb2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 12 Apr 2026 20:31:35 -0400 Subject: [PATCH 48/64] Added keep_gfr and keep_burnin to bart C++ interface --- R/bart.R | 2 ++ include/stochtree/bart.h | 2 ++ src/R_bart.cpp | 2 ++ src/py_stochtree.cpp | 2 ++ stochtree/bart.py | 2 ++ 5 files changed, 10 insertions(+) diff --git a/R/bart.R b/R/bart.R index c58493cd..cf16b6b7 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1107,6 +1107,8 @@ bart <- function( ifelse(outcome_model$outcome == "binary", 1, 2) ), "random_seed" = random_seed, + "keep_gfr" = keep_gfr, + "keep_burnin" = keep_burnin, "a_sigma2_global" = a_global, "b_sigma2_global" = b_global, "sigma2_global_init" = sigma2_init, diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 31562250..66fc9c66 100644 --- a/include/stochtree/bart.h +++ 
b/include/stochtree/bart.h @@ -64,6 +64,8 @@ struct BARTConfig { LinkFunction link_function = LinkFunction::Identity; // link function to use (Identity, Probit, Cloglog) OutcomeType outcome_type = OutcomeType::Continuous; // type of the outcome variable (Continuous, Binary, Ordinal) int random_seed = -1; // random seed for reproducibility (if negative, a random seed will be generated) + bool keep_gfr = true; // whether or not to keep GFR samples or simply use them to warm-start an MCMC chain + bool keep_burnin = false; // whether or not to keep "burn-in" MCMC samples (largely a debugging flag) // Global error variance parameters double a_sigma2_global = 0.0; // shape parameter for inverse gamma prior on global error variance diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 51e5ebec..70b47b7a 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -59,6 +59,8 @@ StochTree::BARTConfig convert_list_to_config(cpp11::list config) { output.link_function = static_cast(get_config_scalar_default(config, "link_function", 0)); output.outcome_type = static_cast(get_config_scalar_default(config, "outcome_type", 0)); output.random_seed = get_config_scalar_default(config, "random_seed", 1); + output.keep_gfr = get_config_scalar_default(config, "keep_gfr", true); + output.keep_burnin = get_config_scalar_default(config, "keep_burnin", false); // Global error variance parameters output.a_sigma2_global = get_config_scalar_default(config, "a_sigma2_global", 0.0); diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index a00799e8..4ab96431 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2175,6 +2175,8 @@ inline StochTree::BARTConfig convert_dict_to_bart_config(py::dict config_dict) { output.link_function = static_cast(get_config_scalar_default(config_dict, "link_function", 0)); output.outcome_type = static_cast(get_config_scalar_default(config_dict, "outcome_type", 0)); output.random_seed = get_config_scalar_default(config_dict, "random_seed", 1); + output.keep_gfr 
= get_config_scalar_default(config_dict, "keep_gfr", 1); + output.keep_burnin = get_config_scalar_default(config_dict, "keep_burnin", 1); // Global error variance parameters output.a_sigma2_global = get_config_scalar_default(config_dict, "a_sigma2_global", 0.0); diff --git a/stochtree/bart.py b/stochtree/bart.py index c45e63d4..91330d76 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1106,6 +1106,8 @@ def sample( "link_function": 0 if self.outcome_model.link == "identity" else (1 if self.outcome_model.link == "probit" else 2), "outcome_type": 0 if self.outcome_model.outcome == "continuous" else (1 if self.outcome_model.outcome == "binary" else 2), "random_seed": random_seed, + "keep_gfr": keep_gfr, + "keep_burnin": keep_burnin, "a_sigma2_global": a_global, "b_sigma2_global": b_global, "sigma2_global_init": 1.0, # TODO: calibrate this before From a5163913be6e67d6e95122fafe5ddcf1df1c7731 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 12 Apr 2026 20:36:42 -0400 Subject: [PATCH 49/64] Pass through num_features_subsample to GFR in C++ --- src/bart_sampler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 7fb8a7af..f670240e 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -202,7 +202,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode config_.var_weights_mean, config_.sweep_update_indices_mean, global_variance_, config_.feature_types, config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/true, - /*num_features_subsample=*/data_.p, config_.num_threads); + /*num_features_subsample=*/config_.num_features_subsample_mean, config_.num_threads); } else { MCMCSampleOneIter( *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, @@ -221,7 +221,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode config_.var_weights_variance, 
config_.sweep_update_indices_variance, global_variance_, config_.feature_types, config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, - /*num_features_subsample=*/data_.p, config_.num_threads); + /*num_features_subsample=*/config_.num_features_subsample_variance, config_.num_threads); } else { MCMCSampleOneIter( *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, From 1ddc788aa54b1a159190785c6d93ca96399bdd37 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 13 Apr 2026 13:31:58 -0500 Subject: [PATCH 50/64] Propagate keep_gfr and fix bugs in the bart_debug program --- debug/bart_debug.cpp | 4 ++-- src/R_bart.cpp | 2 +- src/py_stochtree.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index e00dc458..1223179c 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -164,8 +164,8 @@ static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, StochTree::BARTSamples samples; StochTree::BARTSampler sampler(samples, config, data); - sampler.run_gfr(samples, config, data, num_gfr, true); - sampler.run_mcmc(samples, config, data, 0, 1, num_mcmc); + sampler.run_gfr(samples, num_gfr, true); + sampler.run_mcmc(samples, 0, 1, num_mcmc); // Predictions are on latent scale (= raw + y_bar); compare to true latent Z. 
report_bart(samples, test.Z, "Scenario 1 (Probit BART)"); } diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 70b47b7a..31f75a23 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -258,7 +258,7 @@ cpp11::writable::list bart_sample_cpp( StochTree::BARTSampler bart_sampler(results_raw, config, data); // Run the sampler - bart_sampler.run_gfr(results_raw, num_gfr, true); + bart_sampler.run_gfr(results_raw, num_gfr, config.keep_gfr); bart_sampler.run_mcmc(results_raw, num_burnin, keep_every, num_mcmc); bart_sampler.postprocess_samples(results_raw); diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 4ab96431..21d48913 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2433,7 +2433,7 @@ py::dict bart_sample_cpp( StochTree::BARTSampler bart_sampler(bart_results_raw, bart_config, bart_data); // Run the sampler - bart_sampler.run_gfr(bart_results_raw, num_gfr, true); + bart_sampler.run_gfr(bart_results_raw, num_gfr, bart_config.keep_gfr); bart_sampler.run_mcmc(bart_results_raw, num_burnin, keep_every, num_mcmc); bart_sampler.postprocess_samples(bart_results_raw); From 4be1e9c2a6a0248bf59f20908b4e5c9eecd7efae Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 14 Apr 2026 16:38:22 -0500 Subject: [PATCH 51/64] Make probit and predict work in R --- R/bart.R | 184 ++++++++++++++------ debug/benchmark_cpp_vs_py_sampler_probit.py | 144 +++++++++++++++ debug/benchmark_cpp_vs_r_sampler_probit.R | 179 +++++++++++++++++++ src/R_bart.cpp | 18 +- src/bart_sampler.cpp | 122 ++++++++----- 5 files changed, 546 insertions(+), 101 deletions(-) create mode 100644 debug/benchmark_cpp_vs_py_sampler_probit.py create mode 100644 debug/benchmark_cpp_vs_r_sampler_probit.R diff --git a/R/bart.R b/R/bart.R index cf16b6b7..67335feb 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1090,6 +1090,44 @@ bart <- function( leaf_regression = FALSE } + model_params_r <- list( + "a_global" = a_global, + "b_global" = b_global, + "a_leaf" = a_leaf, + "standardize" = standardize, + 
"leaf_dimension" = leaf_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = leaf_regression, + "num_covariates" = num_cov_orig, + "num_basis" = ifelse( + is.null(leaf_basis_train), + 0, + ncol(leaf_basis_train) + ), + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_basis" = !is.null(leaf_basis_train), + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf" = sample_sigma2_leaf, + "include_mean_forest" = include_mean_forest, + "include_variance_forest" = include_variance_forest, + "outcome_model" = outcome_model, + "probit_outcome_model" = probit_outcome_model, + "cloglog_num_categories" = ifelse( + link_is_cloglog, + max(y_train - min(y_train)) + 1, + 0 + ), + "rfx_model_spec" = rfx_model_spec + ) + if (run_cpp) { # Specify the BART config bart_config <- list( @@ -1205,28 +1243,102 @@ bart <- function( num_mcmc = as.integer(num_mcmc), config_input = bart_config ) - result <- bart_results - # TODO: store num_samples in the result list - if (!is.null(result['mean_forest_predictions_train'])) { - dim(result[['mean_forest_predictions_train']]) <- c( - result[["num_train"]], - result[["num_samples"]] + result <- list() + model_params_cpp <- list( + "sigma2_init" = bart_results[["sigma2_global_init"]], + "sigma2_leaf_init" = bart_results[["sigma2_mean_init"]], + "b_leaf" = bart_results[["b_sigma2_mean"]], + "a_forest" = bart_results[["shape_variance_forest"]], + "b_forest" = bart_results[["scale_variance_forest"]], + "outcome_mean" = bart_results[["y_bar"]], + "outcome_scale" = bart_results[["y_std"]], + "num_samples" = bart_results[["num_samples"]] + ) + model_params <- c(model_params_r, model_params_cpp) + result[["model_params"]] <- model_params + result[["train_set_metadata"]] <- X_train_metadata + + # Unpack 
mean forest predictions if they were returned + has_mean_forest_predictions_train <- !is.null( + bart_results[['mean_forest_predictions_train']] + ) + has_mean_forest_predictions_test <- !is.null( + bart_results[['mean_forest_predictions_test']] + ) + if (has_mean_forest_predictions_train) { + dim(bart_results[['mean_forest_predictions_train']]) <- c( + bart_results[["num_train"]], + bart_results[["num_samples"]] ) - y_hat_train_raw <- result[["mean_forest_predictions_train"]] + y_hat_train_raw <- bart_results[["mean_forest_predictions_train"]] result[["y_hat_train"]] <- y_hat_train_raw * - result[["y_std"]] + - result[["y_bar"]] + bart_results[["y_std"]] + + bart_results[["y_bar"]] } - if (!is.null(result['mean_forest_predictions_test'])) { - dim(result[['mean_forest_predictions_test']]) <- c( - result[["num_test"]], - result[["num_samples"]] + if (has_mean_forest_predictions_test) { + dim(bart_results[['mean_forest_predictions_test']]) <- c( + bart_results[["num_test"]], + bart_results[["num_samples"]] ) - y_hat_test_raw <- result[["mean_forest_predictions_test"]] + y_hat_test_raw <- bart_results[["mean_forest_predictions_test"]] result[["y_hat_test"]] <- y_hat_test_raw * - result[["y_std"]] + - result[["y_bar"]] + bart_results[["y_std"]] + + bart_results[["y_bar"]] + } + if (has_mean_forest_predictions_train || has_mean_forest_predictions_test) { + mean_forests_r <- ForestSamples$new( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + mean_forests_r$forest_container_ptr <- bart_results[[ + "mean_forests" + ]] + result[["mean_forests"]] <- mean_forests_r } + + # Unpack variance forest predictions if they were returned + has_variance_forest_predictions_train <- !is.null( + bart_results[['variance_forest_predictions_train']] + ) + has_variance_forest_predictions_test <- !is.null( + bart_results[['variance_forest_predictions_test']] + ) + if (has_variance_forest_predictions_train) { + dim(bart_results[['variance_forest_predictions_train']]) <- c( 
+ bart_results[["num_train"]], + bart_results[["num_samples"]] + ) + result[["sigma2_x_hat_train"]] <- bart_results[[ + "variance_forest_predictions_train" + ]] + } + if (has_variance_forest_predictions_test) { + dim(bart_results[['variance_forest_predictions_test']]) <- c( + bart_results[["num_test"]], + bart_results[["num_samples"]] + ) + result[["sigma2_x_hat_test"]] <- bart_results[[ + "variance_forest_predictions_test" + ]] + } + if ( + has_variance_forest_predictions_train || + has_variance_forest_predictions_test + ) { + variance_forests_r <- ForestSamples$new( + num_trees_variance, + 1, + FALSE, + TRUE + ) + variance_forests_r$forest_container_ptr <- bart_results[[ + "variance_forests" + ]] + result[["variance_forests"]] <- variance_forests_r + } + class(result) <- "bartmodel" } else { # Set a function-scoped RNG if user provided a random seed @@ -2550,51 +2662,17 @@ bart <- function( } # Return results as a list - model_params <- list( + model_params_r_calibrated <- list( "sigma2_init" = sigma2_init, "sigma2_leaf_init" = sigma2_leaf_init, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf" = a_leaf, "b_leaf" = b_leaf, "a_forest" = a_forest, "b_forest" = b_forest, "outcome_mean" = y_bar_train, "outcome_scale" = y_std_train, - "standardize" = standardize, - "leaf_dimension" = leaf_dimension, - "is_leaf_constant" = is_leaf_constant, - "leaf_regression" = leaf_regression, - "requires_basis" = requires_basis, - "num_covariates" = num_cov_orig, - "num_basis" = ifelse( - is.null(leaf_basis_train), - 0, - ncol(leaf_basis_train) - ), - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_basis" = !is.null(leaf_basis_train), - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf" = sample_sigma2_leaf, - "include_mean_forest" 
= include_mean_forest, - "include_variance_forest" = include_variance_forest, - "outcome_model" = outcome_model, - "probit_outcome_model" = probit_outcome_model, - "cloglog_num_categories" = ifelse( - link_is_cloglog, - cloglog_num_categories, - 0 - ), - "rfx_model_spec" = rfx_model_spec + "num_samples" = num_retained_samples ) + model_params <- c(model_params_r, model_params_r_calibrated) result <- list( "model_params" = model_params, "train_set_metadata" = X_train_metadata diff --git a/debug/benchmark_cpp_vs_py_sampler_probit.py b/debug/benchmark_cpp_vs_py_sampler_probit.py new file mode 100644 index 00000000..1a709efd --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler_probit.py @@ -0,0 +1,144 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop – probit BART. + +Compares runtime, Brier score, and RMSE-to-truth (vs. pnorm(f_X)) across +run_cpp=True / False in BARTModel.sample(). + +Usage: + source venv/bin/activate # or: conda activate stochtree-book + python debug/benchmark_cpp_vs_py_sampler_probit.py +""" + +import time +import numpy as np +from scipy.stats import norm +from stochtree import BARTModel, OutcomeModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 2000 +p = 10 +X = rng.uniform(size=(n, p)) + +# Latent mean on the probit (standard-normal) scale – same step function as +# the continuous benchmark, keeping values well within identifiable range. 
+f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -7.5, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -2.5, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 2.5, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 7.5, 0) +) +p_X = norm.cdf(f_X) # true P(Y = 1 | X) +y = rng.binomial(1, p_X).astype(float) # observed binary outcome + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +p_test = p_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr = 10 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 + +print( + f"n_train={n_train} n_test={n_test} p={p} " + f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=0, + num_mcmc=num_mcmc, + mean_forest_params={"num_trees": num_trees}, + general_params={ + "random_seed": seed, + "outcome_model": OutcomeModel(outcome="binary", link="probit"), + "sample_sigma2_global": False, + }, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + # Posterior-mean predicted probability on the test set + preds = m.predict(X=X_test, scale="probability") + p_hat = preds["y_hat"].mean(axis=1) # (n_test,) + + brier = float(np.mean((p_hat - y_test) ** 2)) # Brier score + rmse_p = 
float(np.sqrt(np.mean((p_hat - p_test) ** 2))) # RMSE vs pnorm(f_X) + + return {"elapsed": elapsed, "brier": brier, "rmse_p": rmse_p} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results: list) -> dict: + elapsed = [r["elapsed"] for r in results] + brier = [r["brier"] for r in results] + rmse_p = [r["rmse_p"] for r in results] + return { + "elapsed_mean": np.mean(elapsed), + "elapsed_sd": np.std(elapsed, ddof=1), + "brier_mean": np.mean(brier), + "rmse_p_mean": np.mean(rmse_p), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print( + f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} " + f"{'Brier':>12} {'RMSE (vs pnorm)':>16}" +) +print("-" * 76) +for label, s in rows: + print( + f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['brier_mean']:>12.4f} {s['rmse_p_mean']:>16.4f}" + ) + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"Brier delta (cpp - py): {s_cpp['brier_mean'] - s_py['brier_mean']:.4f}\n" + f"RMSE-p delta (cpp - py): 
{s_cpp['rmse_p_mean'] - s_py['rmse_p_mean']:.4f}" +) diff --git a/debug/benchmark_cpp_vs_r_sampler_probit.R b/debug/benchmark_cpp_vs_r_sampler_probit.R new file mode 100644 index 00000000..153e5b64 --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler_probit.R @@ -0,0 +1,179 @@ +## Benchmark: C++ sampler loop vs. R sampler loop – probit BART +## Compares runtime, Brier score, and RMSE-to-truth (vs. pnorm(f_X)) across +## run_cpp = TRUE / FALSE in bart(). +## +## Usage: Rscript debug/benchmark_cpp_vs_r_sampler_probit.R +## or source() from an interactive session after devtools::load_all('.') +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) + +# Latent mean on the probit (standard-normal) scale – same step function as +# the continuous benchmark, keeping values well within identifiable range. 
+f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-7.5) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-2.5) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (2.5) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (7.5)) +p_X <- pnorm(f_X) # true P(Y = 1 | X) +y <- rbinom(n, 1L, p_X) # observed binary outcome + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +p_test <- p_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_mcmc <- 100 +num_trees <- 200 +n_reps <- 3 + +cat(sprintf( + "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + n_train, + n_test, + p, + num_trees, + num_gfr, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = num_mcmc, + mean_forest_params = list(num_trees = num_trees), + general_params = list( + random_seed = seed, + outcome_model = OutcomeModel(outcome = "binary", link = "probit"), + sample_sigma2_global = FALSE + ), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - t0)[["elapsed"]] + + # Posterior-mean predicted probability on the test set + p_hat_mat <- predict( + m, + X = X_test, + type = "posterior", + terms = "y_hat", + scale = "probability" + ) + if (is.null(dim(p_hat_mat))) { + p_hat_mat <- matrix(p_hat_mat, ncol = 1) + } + p_hat 
<- rowMeans(p_hat_mat) + + brier <- mean((p_hat - y_test)^2) # Brier score (lower is better) + rmse_p <- sqrt(mean((p_hat - p_test)^2)) # RMSE vs. true pnorm(f_X) + + list(elapsed = elapsed, brier = brier, rmse_p = rmse_p) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once( + run_cpp = TRUE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once( + run_cpp = FALSE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + brier <- sapply(results, `[[`, "brier") + rmse_p <- sapply(results, `[[`, "rmse_p") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = sd(elapsed), + brier_mean = mean(brier), + rmse_p_mean = mean(rmse_p), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %16s\n", + "Sampler", + "Time (s)", + "SD", + "Brier", + "RMSE (vs pnorm)" +)) +cat(strrep("-", 76), "\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12.4f %16.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + res$brier_mean[i], + 
res$rmse_p_mean[i] + )) +} + +speedup <- res$elapsed_mean[2] / res$elapsed_mean[1] +cat(sprintf("\nSpeedup (R / C++): %.2fx\n", speedup)) +cat(sprintf( + "Brier delta (cpp - R): %.4f\nRMSE-p delta (cpp - R): %.4f\n", + res$brier_mean[1] - res$brier_mean[2], + res$rmse_p_mean[1] - res$rmse_p_mean[2] +)) diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 31f75a23..7156eb06 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -174,7 +174,7 @@ cpp11::writable::list convert_bart_results_to_list(StochTree::BARTSamples& bart_ : R_NilValue; output.push_back(cpp11::named_arg("leaf_scale_samples") = leaf_scale_sexp); - // Sample metadata + // Metadata about the model that was sampled double y_bar_sexp = bart_samples.y_bar; output.push_back(cpp11::named_arg("y_bar") = y_bar_sexp); double y_std_sexp = bart_samples.y_std; @@ -185,10 +185,19 @@ cpp11::writable::list convert_bart_results_to_list(StochTree::BARTSamples& bart_ output.push_back(cpp11::named_arg("num_train") = num_train_sexp); int num_test_sexp = bart_samples.num_test; output.push_back(cpp11::named_arg("num_test") = num_test_sexp); - return output; } +void add_config_to_result_list(cpp11::writable::list& result, StochTree::BARTConfig& config) { + // Unpack more metadata about the model that was sampled + result.push_back(cpp11::named_arg("sigma2_global_init") = config.sigma2_global_init); + result.push_back(cpp11::named_arg("sigma2_mean_init") = config.sigma2_mean_init); + result.push_back(cpp11::named_arg("b_sigma2_mean") = config.b_sigma2_mean); + result.push_back(cpp11::named_arg("shape_variance_forest") = config.shape_variance_forest); + result.push_back(cpp11::named_arg("scale_variance_forest") = config.scale_variance_forest); + return; +} + [[cpp11::register]] cpp11::writable::list bart_sample_cpp( cpp11::sexp X_train, @@ -265,5 +274,8 @@ cpp11::writable::list bart_sample_cpp( // Unprotect protected R objects UNPROTECT(protect_count); - return convert_bart_results_to_list(results_raw); + // Unpack outputs + 
cpp11::writable::list output_list = convert_bart_results_to_list(results_raw); + add_config_to_result_list(output_list, config); + return output_list; } diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index f670240e..27ac3a52 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -51,33 +51,24 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } double y_var = M2 / data_.n_train; - // Compute outcome location and scale for standardization - if (config_.link_function == LinkFunction::Probit) { - samples.y_std = 1.0; - samples.y_bar = norm_inv_cdf(y_mean); - } else { - if (config_.standardize_outcome) { - samples.y_bar = y_mean; - samples.y_std = std::sqrt(y_var); - } else { - samples.y_bar = 0.0; + // Standardization and calibration for mean forests + double init_val_mean; + if (config_.num_trees_mean > 0) { + if (config_.link_function == LinkFunction::Probit) { samples.y_std = 1.0; + samples.y_bar = norm_inv_cdf(y_mean); + init_val_mean = 0.0; + } else { + if (config_.standardize_outcome) { + samples.y_bar = y_mean; + samples.y_std = std::sqrt(y_var); + init_val_mean = 0.0; + } else { + samples.y_bar = 0.0; + samples.y_std = 1.0; + init_val_mean = y_mean; + } } - } - - // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized outcomes - for (int i = 0; i < data_.n_train; i++) residual_->GetData()[i] = (data_.y_train[i] - samples.y_bar) / samples.y_std; - - // Initialize mean forest state (if present) - if (config_.num_trees_mean > 0) { - mean_forest_ = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); - samples.mean_forests = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); - mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_mean, data_.n_train); - tree_prior_mean_ = 
std::make_unique(config_.alpha_mean, config_.beta_mean, config_.min_samples_leaf_mean, config_.max_depth_mean); - mean_forest_->SetLeafValue(0.0); - UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config_.leaf_constant_mean, std::minus()); - mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); - has_mean_forest_ = true; if (config_.sigma2_mean_init < 0.0) { if (config_.link_function == LinkFunction::Probit) { config_.sigma2_mean_init = 1.0 / config_.num_trees_mean; @@ -96,15 +87,10 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } } - // Initialize variance forest state (if present) + // Calibration for variance forests + double init_val_variance; if (config_.num_trees_variance > 0) { - variance_forest_ = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); - samples.variance_forests = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); - variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_variance, data_.n_train); - tree_prior_variance_ = std::make_unique(config_.alpha_variance, config_.beta_variance, config_.min_samples_leaf_variance, config_.max_depth_variance); - variance_forest_->SetLeafValue(1.0 / config_.num_trees_variance); - variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); - has_variance_forest_ = true; + // NOTE: calibration only works for standardized outcomes if (config_.shape_variance_forest <= 0.0 || config_.scale_variance_forest <= 0.0) { if (config_.leaf_prior_calibration_param <= 0.0) { config_.leaf_prior_calibration_param = 1.5; @@ -116,6 +102,41 @@ void BARTSampler::InitializeState(BARTSamples& samples) { config_.scale_variance_forest = config_.num_trees_variance / 
(config_.leaf_prior_calibration_param * config_.leaf_prior_calibration_param); } } + if (config_.standardize_outcome) { + init_val_variance = 1.0; + } else { + init_val_variance = y_var; + } + } + + // Standardize partial residuals in place; these are updated in each iteration but initialized to standardized outcomes + // Works for: + // 1. Standardized outcomes (since y_bar = mean(y) and y_std = sd(y)) + // 2. Non-standardized outcomes (since y_bar = 0 and y_std = 1, so this just transfers y_train as-is) + // 3. Probit link (since y_bar = norm_inv_cdf(mean(y)) and y_std = 1) + for (int i = 0; i < data_.n_train; i++) residual_->GetData()[i] = (data_.y_train[i] - samples.y_bar) / samples.y_std; + + // Initialize mean forest state (if present) + if (config_.num_trees_mean > 0) { + mean_forest_ = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); + mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_mean, data_.n_train); + tree_prior_mean_ = std::make_unique(config_.alpha_mean, config_.beta_mean, config_.min_samples_leaf_mean, config_.max_depth_mean); + mean_forest_->SetLeafValue(init_val_mean / config_.num_trees_mean); + UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config_.leaf_constant_mean, std::minus()); + mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); + has_mean_forest_ = true; + } + + // Initialize variance forest state (if present) + if (config_.num_trees_variance > 0) { + variance_forest_ = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); + samples.variance_forests = 
std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); + variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_variance, data_.n_train); + tree_prior_variance_ = std::make_unique(config_.alpha_variance, config_.beta_variance, config_.min_samples_leaf_variance, config_.max_depth_variance); + variance_forest_->SetLeafValue(init_val_variance / config_.num_trees_variance); + variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); + has_variance_forest_ = true; } // Global error variance model @@ -142,12 +163,7 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } void BARTSampler::run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr) { - // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance - std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); - std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); - for (int i = 0; i < num_gfr; i++) { - RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); - } + // Reserve space for GFR predictions if they are to be retained if (keep_gfr) { if (has_mean_forest_) { samples.mean_forest_predictions_train.reserve(data_.n_train * num_gfr); @@ -156,12 +172,17 @@ void BARTSampler::run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr) { samples.variance_forest_predictions_train.reserve(data_.n_train * num_gfr); } } -} -void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc) { + // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance std::unique_ptr mean_leaf_model_ptr = 
std::make_unique(leaf_scale_); std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); - bool keep_forest = false; + for (int i = 0; i < num_gfr; i++) { + RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); + } +} + +void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, int num_mcmc) { + // Reserve space for MCMC predictions if they are to be retained if (has_mean_forest_) { samples.mean_forest_predictions_train.reserve(data_.n_train * num_mcmc); if (has_test_) { @@ -171,9 +192,14 @@ void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, if (has_variance_forest_) { samples.variance_forest_predictions_train.reserve(data_.n_train * num_mcmc); if (has_test_) { - samples.variance_forest_predictions_train.reserve(data_.n_test * num_mcmc); + samples.variance_forest_predictions_test.reserve(data_.n_test * num_mcmc); } } + + // Create leaf models and pass them to the RunOneIteration function; these are updated in place and will reflect the current state of the leaf scale parameters (if they are being sampled) + std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); + std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); + bool keep_forest = false; for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; @@ -184,12 +210,18 @@ void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, } void BARTSampler::postprocess_samples(BARTSamples& samples) { - if (has_mean_forest_) { - if (has_test_) { + // Unpack test set predictions for mean and variance forest + if (has_test_) { + if (has_mean_forest_) { std::vector predictions = samples.mean_forests->Predict(*forest_dataset_test_); 
samples.mean_forest_predictions_test.insert(samples.mean_forest_predictions_test.end(), predictions.data(), predictions.data() + predictions.size()); } + if (has_variance_forest_) { + std::vector predictions = samples.variance_forests->Predict(*forest_dataset_test_); + samples.variance_forest_predictions_test.insert(samples.variance_forest_predictions_test.end(), + predictions.data(), predictions.data() + predictions.size()); + } } } From f1a50fc3ac4ca1150be885b22e90e138849d51c7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 14 Apr 2026 16:58:59 -0500 Subject: [PATCH 52/64] Updated python interface to ensure probit and predict work --- src/py_stochtree.cpp | 9 ++++++++- stochtree/bart.py | 17 +++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 21d48913..836fd237 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2397,6 +2397,11 @@ inline py::dict convert_bart_results_to_dict( return output; } +void add_config_to_result_dict(py::dict& result, StochTree::BARTConfig& config) { + // Unpack more metadata about the model that was sampled + result["sigma2_init"] = config.sigma2_global_init; +} + py::dict bart_sample_cpp( py::object X_train, py::object y_train, @@ -2438,7 +2443,9 @@ py::dict bart_sample_cpp( bart_sampler.postprocess_samples(bart_results_raw); // Convert results to Python dictionary - return convert_bart_results_to_dict(bart_results_raw, bart_config); + py::dict bart_results = convert_bart_results_to_dict(bart_results_raw, bart_config); + add_config_to_result_dict(bart_results, bart_config); + return bart_results; } py::array_t cppComputeForestContainerLeafIndices(ForestContainerCpp& forest_container, py::array_t& covariates, py::array_t& forest_nums) { diff --git a/stochtree/bart.py b/stochtree/bart.py index 91330d76..e79c7374 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1185,6 +1185,13 @@ def sample( config_input = bart_config ) + # Store high 
level model metadata from C++ results + self.num_gfr = num_gfr + self.num_burnin = num_burnin + self.keep_every = keep_every + self.num_mcmc = num_mcmc + self.num_chains = num_chains + # Unpack standardization params computed by C++ sampler self.y_bar = bart_results["y_bar"] self.y_std = bart_results["y_std"] @@ -1207,15 +1214,17 @@ def sample( self.forest_container_variance = ForestContainer(num_trees_variance, 1, True, True) self.forest_container_variance.forest_container_cpp = bart_results["forest_container_variance"] variance_forest_preds_train = bart_results["variance_forest_predictions_train"].reshape(self.n_train, bart_results["num_samples"], order="F") - self.variance_forest_preds_train = variance_forest_preds_train * self.y_std * self.y_std + self.sigma2_x_train = variance_forest_preds_train * self.y_std * self.y_std if self.has_test: variance_forest_preds_test = bart_results["variance_forest_predictions_test"].reshape(self.n_test, bart_results["num_samples"], order="F") - self.variance_forest_preds_test = variance_forest_preds_test * self.y_std * self.y_std + self.sigma2_x_test = variance_forest_preds_test * self.y_std * self.y_std # Unpack parameter samples - if sample_sigma2_global: + self.sample_sigma2_global = sample_sigma2_global + self.sample_sigma2_leaf = sample_sigma2_leaf + if self.sample_sigma2_global: self.global_var_samples = bart_results["global_var_samples"] * self.y_std * self.y_std - if sample_sigma2_leaf: + if self.sample_sigma2_leaf: self.leaf_scale_samples = bart_results["leaf_scale_samples"] # Unpack other model metadata From 4e366d29fe2b3183545c2f49e785f8a8fcc3deaa Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Apr 2026 11:44:11 -0500 Subject: [PATCH 53/64] Unpack more fields from the python C++ interface --- stochtree/bart.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stochtree/bart.py b/stochtree/bart.py index e79c7374..bd50b985 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1191,10 +1191,13 @@ def 
sample( self.keep_every = keep_every self.num_mcmc = num_mcmc self.num_chains = num_chains + self.sample_sigma2_global = sample_sigma2_global + self.sample_sigma2_leaf = sample_sigma2_leaf # Unpack standardization params computed by C++ sampler self.y_bar = bart_results["y_bar"] self.y_std = bart_results["y_std"] + self.sigma2_init = bart_results["sigma2_init"] # Unpack mean forest results self.forest_container_mean = ( From 1dbcafc92f73ddbb313feea8c75a6a2bc5f797fa Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Apr 2026 13:18:01 -0500 Subject: [PATCH 54/64] Extract calibrated fields from the python C++ BART interface --- src/py_stochtree.cpp | 4 ++++ stochtree/bart.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 836fd237..72995b34 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2400,6 +2400,10 @@ inline py::dict convert_bart_results_to_dict( void add_config_to_result_dict(py::dict& result, StochTree::BARTConfig& config) { // Unpack more metadata about the model that was sampled result["sigma2_init"] = config.sigma2_global_init; + result["sigma2_mean_init"] = config.sigma2_mean_init; + result["b_sigma2_mean"] = config.b_sigma2_mean; + result["shape_variance_forest"] = config.shape_variance_forest; + result["scale_variance_forest"] = config.scale_variance_forest; } py::dict bart_sample_cpp( diff --git a/stochtree/bart.py b/stochtree/bart.py index bd50b985..67eee8c5 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1198,6 +1198,10 @@ def sample( self.y_bar = bart_results["y_bar"] self.y_std = bart_results["y_std"] self.sigma2_init = bart_results["sigma2_init"] + self.sigma2_leaf_init = bart_results["sigma2_leaf_init"] if self.include_mean_forest else None + self.b_leaf = bart_results["b_sigma2_mean"] if self.include_mean_forest else None + self.shape_variance_forest = bart_results["shape_variance_forest"] if self.include_variance_forest else None + 
self.scale_variance_forest = bart_results["scale_variance_forest"] if self.include_variance_forest else None # Unpack mean forest results self.forest_container_mean = ( @@ -1434,6 +1438,9 @@ def sample( a_forest = 1.0 if not b_forest: b_forest = 1.0 + self.shape_variance_forest = a_forest + self.scale_variance_forest = b_forest + self.sigma2_leaf_init = bart_results["sigma2_leaf_init"] if self.include_mean_forest else None # Set up random effects structures if self.has_rfx: From 00473eeae9fa09332bc856b7e29c976b185bd8d2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Apr 2026 14:39:44 -0500 Subject: [PATCH 55/64] Non-working model type dispatch refactor and cloglog implementation --- R/bart.R | 29 ++- debug/bart_debug.cpp | 1 + debug/benchmark_cpp_vs_py_sampler_cloglog.py | 147 +++++++++++++++ debug/benchmark_cpp_vs_r_sampler_cloglog.R | 183 +++++++++++++++++++ include/stochtree/bart.h | 17 ++ include/stochtree/bart_sampler.h | 126 ++++++++++++- src/R_bart.cpp | 10 + src/bart_sampler.cpp | 142 +++++++++----- src/py_stochtree.cpp | 15 ++ stochtree/bart.py | 66 ++++--- 10 files changed, 667 insertions(+), 69 deletions(-) create mode 100644 debug/benchmark_cpp_vs_py_sampler_cloglog.py create mode 100644 debug/benchmark_cpp_vs_r_sampler_cloglog.R diff --git a/R/bart.R b/R/bart.R index 67335feb..e557e557 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1090,6 +1090,11 @@ bart <- function( leaf_regression = FALSE } + cloglog_num_categories <- ifelse( + link_is_cloglog, + max(y_train - min(y_train)) + 1, + 0 + ) model_params_r <- list( "a_global" = a_global, "b_global" = b_global, @@ -1120,11 +1125,7 @@ bart <- function( "include_variance_forest" = include_variance_forest, "outcome_model" = outcome_model, "probit_outcome_model" = probit_outcome_model, - "cloglog_num_categories" = ifelse( - link_is_cloglog, - max(y_train - min(y_train)) + 1, - 0 - ), + "cloglog_num_categories" = cloglog_num_categories, "rfx_model_spec" = rfx_model_spec ) @@ -1164,6 +1165,11 @@ bart <- 
function( "b_sigma2_mean" = b_leaf, "sigma2_mean_init" = sigma2_leaf_init, "sample_sigma2_leaf_mean" = sample_sigma2_leaf, + "mean_leaf_model_type" = leaf_model_mean_forest, + "num_classes_cloglog" = cloglog_num_categories, + "cloglog_leaf_prior_shape" = cloglog_leaf_prior_shape, + "cloglog_leaf_prior_scale" = cloglog_leaf_prior_scale, + "cloglog_cutpoint_0" = 0, "num_trees_variance" = num_trees_variance, "leaf_prior_calibration_param" = a_0, "shape_variance_forest" = a_forest, @@ -1339,6 +1345,19 @@ bart <- function( result[["variance_forests"]] <- variance_forests_r } + has_cloglog_cutpoint_samples <- !is.null( + bart_results[['cloglog_cutpoint_samples']] + ) + if (has_cloglog_cutpoint_samples) { + dim(bart_results[['cloglog_cutpoint_samples']]) <- c( + cloglog_num_categories - 1, + bart_results[["num_samples"]] + ) + result[["cloglog_cutpoint_samples"]] <- t(bart_results[[ + "cloglog_cutpoint_samples" + ]]) + } + class(result) <- "bartmodel" } else { # Set a function-scoped RNG if user provided a random seed diff --git a/debug/bart_debug.cpp b/debug/bart_debug.cpp index 1223179c..8da1bb71 100644 --- a/debug/bart_debug.cpp +++ b/debug/bart_debug.cpp @@ -155,6 +155,7 @@ static void run_scenario_1(int n, int n_test, int p, int num_trees, int num_gfr, StochTree::BARTConfig config; config.num_trees_mean = num_trees; config.random_seed = seed; + config.mean_leaf_model_type = StochTree::MeanLeafModelType::GaussianConstant; config.link_function = StochTree::LinkFunction::Probit; config.sample_sigma2_global = false; config.var_weights_mean = std::vector(p, 1.0 / p); diff --git a/debug/benchmark_cpp_vs_py_sampler_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_cloglog.py new file mode 100644 index 00000000..67e63f09 --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler_cloglog.py @@ -0,0 +1,147 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop – cloglog BART. + +Compares runtime, Brier score, and RMSE-to-truth (vs. 
true P(Y=1|X)) across +run_cpp=True / False in BARTModel.sample(). + +DGP uses the cloglog link: P(Y=1|X) = 1 - exp(-exp(f(X))). +The step function for f(X) is kept in the range [-2, 1] so that the implied +probabilities span roughly 0.13 to 0.93 and are well-identified. + +Usage: + conda activate stochtree-book # or: source venv/bin/activate + python debug/benchmark_cpp_vs_py_sampler_cloglog.py +""" + +import time +import numpy as np +from stochtree import BARTModel, OutcomeModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 2000 +p = 10 +X = rng.uniform(size=(n, p)) + +# Latent mean on the cloglog (log-log) scale. +# P(Y=1|X) = 1 - exp(-exp(f_X)); values chosen so probabilities are moderate. +f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -2.0, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -0.5, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 0.5, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 1.0, 0) +) +p_X = 1.0 - np.exp(-np.exp(f_X)) # true P(Y = 1 | X) +y = rng.binomial(1, p_X).astype(float) # observed binary outcome + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +p_test = p_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr = 10 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 + +print( + f"n_train={n_train} n_test={n_test} p={p} " + f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# 
--------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=0, + num_mcmc=num_mcmc, + mean_forest_params={"num_trees": num_trees}, + general_params={ + "random_seed": seed, + "outcome_model": OutcomeModel(outcome="binary", link="cloglog"), + "sample_sigma2_global": False, + }, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + # Posterior-mean predicted probability on the test set + preds = m.predict(X=X_test, scale="probability") + p_hat = preds["y_hat"].mean(axis=1) # (n_test,) + + brier = float(np.mean((p_hat - y_test) ** 2)) + rmse_p = float(np.sqrt(np.mean((p_hat - p_test) ** 2))) + + return {"elapsed": elapsed, "brier": brier, "rmse_p": rmse_p} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results: list) -> dict: + elapsed = [r["elapsed"] for r in results] + 
brier = [r["brier"] for r in results] + rmse_p = [r["rmse_p"] for r in results] + return { + "elapsed_mean": np.mean(elapsed), + "elapsed_sd": np.std(elapsed, ddof=1), + "brier_mean": np.mean(brier), + "rmse_p_mean": np.mean(rmse_p), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print( + f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} " + f"{'Brier':>12} {'RMSE (vs truth)':>15}" +) +print("-" * 75) +for label, s in rows: + print( + f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['brier_mean']:>12.4f} {s['rmse_p_mean']:>15.4f}" + ) + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"Brier delta (cpp - py): {s_cpp['brier_mean'] - s_py['brier_mean']:.4f}\n" + f"RMSE-p delta (cpp - py): {s_cpp['rmse_p_mean'] - s_py['rmse_p_mean']:.4f}" +) diff --git a/debug/benchmark_cpp_vs_r_sampler_cloglog.R b/debug/benchmark_cpp_vs_r_sampler_cloglog.R new file mode 100644 index 00000000..47613dbb --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler_cloglog.R @@ -0,0 +1,183 @@ +## Benchmark: C++ sampler loop vs. R sampler loop – cloglog BART +## Compares runtime, Brier score, and RMSE-to-truth (vs. true P(Y=1|X)) across +## run_cpp = TRUE / FALSE in bart(). +## +## DGP uses the cloglog link: P(Y=1|X) = 1 - exp(-exp(f(X))). +## The step function for f(X) is kept in the range [-2, 1] so that the implied +## probabilities span roughly 0.13 to 0.93 and are well-identified. 
+## +## Usage: Rscript debug/benchmark_cpp_vs_r_sampler_cloglog.R +## or source() from an interactive session after devtools::load_all('.') +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) + +# Latent mean on the cloglog (log-log) scale. +# P(Y=1|X) = 1 - exp(-exp(f_X)); values chosen so probabilities are moderate. +f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-2.0) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-0.5) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (0.5) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (1.0)) +p_X <- 1 - exp(-exp(f_X)) # true P(Y = 1 | X) +y <- rbinom(n, 1L, p_X) # observed binary outcome + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +p_test <- p_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_mcmc <- 100 +num_trees <- 200 +n_reps <- 3 + +cat(sprintf( + "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + n_train, + n_test, + p, + num_trees, + num_gfr, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 
0, + num_mcmc = num_mcmc, + mean_forest_params = list(num_trees = num_trees), + general_params = list( + random_seed = seed, + outcome_model = OutcomeModel(outcome = "binary", link = "cloglog"), + sample_sigma2_global = FALSE + ), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - t0)[["elapsed"]] + + # Posterior-mean predicted probability on the test set + p_hat_mat <- predict( + m, + X = X_test, + type = "posterior", + terms = "y_hat", + scale = "probability" + ) + if (is.null(dim(p_hat_mat))) { + p_hat_mat <- matrix(p_hat_mat, ncol = 1) + } + p_hat <- rowMeans(p_hat_mat) + + brier <- mean((p_hat - y_test)^2) # Brier score (lower is better) + rmse_p <- sqrt(mean((p_hat - p_test)^2)) # RMSE vs. true cloglog probabilities + + list(elapsed = elapsed, brier = brier, rmse_p = rmse_p) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once( + run_cpp = TRUE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once( + run_cpp = FALSE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + brier <- sapply(results, `[[`, "brier") + rmse_p <- sapply(results, `[[`, "rmse_p") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = 
sd(elapsed), + brier_mean = mean(brier), + rmse_p_mean = mean(rmse_p), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %15s\n", + "Sampler", + "Time (s)", + "SD", + "Brier", + "RMSE (vs truth)" +)) +cat(strrep("-", 75), "\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12.4f %15.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + res$brier_mean[i], + res$rmse_p_mean[i] + )) +} + +speedup <- res$elapsed_mean[2] / res$elapsed_mean[1] +cat(sprintf("\nSpeedup (R / C++): %.2fx\n", speedup)) +cat(sprintf( + "Brier delta (cpp - R): %.4f\nRMSE-p delta (cpp - R): %.4f\n", + res$brier_mean[1] - res$brier_mean[2], + res$rmse_p_mean[1] - res$rmse_p_mean[2] +)) diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 66fc9c66..fd14f0f4 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -7,6 +7,7 @@ #include #include +#include "stochtree/leaf_model.h" #include #include @@ -24,6 +25,14 @@ enum class OutcomeType { Ordinal }; +enum class MeanLeafModelType { + GaussianConstant, + GaussianUnivariateRegression, + GaussianMultivariateRegression, + LogLinearVariance, + CloglogOrdinal +}; + struct BARTData { // Train set covariates double* X_train = nullptr; @@ -89,6 +98,11 @@ struct BARTConfig { std::vector var_weights_mean; // variable weights for mean forest splits (should be same length as number of covariates in the dataset) bool sample_sigma2_leaf_mean = false; // whether to sample mean forest leaf scale (if false, it will be fixed at sigma2_mean_init) std::vector sweep_update_indices_mean; // indices of trees to update in a given sweep (should be subset of [0, num_trees - 1]) + MeanLeafModelType mean_leaf_model_type; // leaf model type for mean forest + int num_classes_cloglog = 0; // number of classes for cloglog ordinal leaf model (should be 
set if mean_leaf_model_type = CloglogOrdinal) + double cloglog_leaf_prior_shape = 2.0; // shape parameter for cloglog ordinal leaf model prior + double cloglog_leaf_prior_scale = 2.0; // scale parameter for cloglog ordinal leaf model prior + double cloglog_cutpoint_0 = 0.0; // Fixed value of the first log-scale cutpoint for the cloglog model (defaults to 0 for identifiability) // Variance forest parameters int num_trees_variance = 0; // number of trees in the variance forest @@ -136,6 +150,9 @@ struct BARTSamples { // Pointer to sampled variance forests std::unique_ptr variance_forests; + // Posterior samples of cloglog cutpoint parameters (num_samples x num_classes - 1, stored column-major) + std::vector cloglog_cutpoint_samples; + // TODO: Pointer to random effects samples ... // Metadata about the samples (e.g., number of samples, burn-in, etc.) could be added here as needed diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 14fb1bb9..975844ed 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -14,9 +14,12 @@ #include #include #include +#include #include #include +#include #include +#include "stochtree/ordinal_sampler.h" namespace StochTree { @@ -39,23 +42,144 @@ class BARTSampler { bool initialized_ = false; /*! Internal sample runner function */ - void RunOneIteration(BARTSamples& samples, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample); + void RunOneIteration(BARTSamples& samples, bool gfr, bool keep_sample); + + /*! 
Initialization visitor */ + struct MeanForestInitVisitor { + BARTSampler& sampler; + BARTSamples& samples; + void operator()(GaussianConstantLeafModel& model) { + sampler.mean_forest_ = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + sampler.mean_forest_tracker_ = std::make_unique(sampler.forest_dataset_->GetCovariates(), sampler.config_.feature_types, sampler.config_.num_trees_mean, sampler.data_.n_train); + sampler.tree_prior_mean_ = std::make_unique(sampler.config_.alpha_mean, sampler.config_.beta_mean, sampler.config_.min_samples_leaf_mean, sampler.config_.max_depth_mean); + sampler.mean_forest_->SetLeafValue(sampler.init_val_mean_ / sampler.config_.num_trees_mean); + UpdateResidualEntireForest(*sampler.mean_forest_tracker_, *sampler.forest_dataset_, *sampler.residual_, sampler.mean_forest_.get(), !sampler.config_.leaf_constant_mean, std::minus()); + sampler.mean_forest_tracker_->UpdatePredictions(sampler.mean_forest_.get(), *sampler.forest_dataset_.get()); + sampler.has_mean_forest_ = true; + } + void operator()(GaussianUnivariateRegressionLeafModel& model) { + // TODO ... + } + void operator()(GaussianMultivariateRegressionLeafModel& model) { + // TODO ... + } + void operator()(CloglogOrdinalLeafModel& model) { + // TODO ... + } + }; + + /*! 
GFR iteration visitor */ + struct GFROneIterationVisitor { + BARTSampler& sampler; + BARTSamples& samples; + bool keep_sample; + void operator()(GaussianConstantLeafModel& model) { + GFRSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, sampler.config_.feature_types, + sampler.config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_features_subsample=*/sampler.config_.num_features_subsample_mean, sampler.config_.num_threads); + } + void operator()(GaussianUnivariateRegressionLeafModel& model) { + GFRSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, sampler.config_.feature_types, + sampler.config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_features_subsample=*/sampler.config_.num_features_subsample_mean, sampler.config_.num_threads); + } + void operator()(GaussianMultivariateRegressionLeafModel& model) { + // TODO ... 
+ } + void operator()(CloglogOrdinalLeafModel& model) { + GFRSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, sampler.config_.feature_types, + sampler.config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/false, + /*num_features_subsample=*/sampler.config_.num_features_subsample_mean, sampler.config_.num_threads); + } + }; + + /*! MCMC iteration visitor */ + struct MCMCOneIterationVisitor { + BARTSampler& sampler; + BARTSamples& samples; + bool keep_sample; + void operator()(GaussianConstantLeafModel& model) { + MCMCSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_threads=*/sampler.config_.num_threads); + } + void operator()(GaussianUnivariateRegressionLeafModel& model) { + MCMCSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/true, + /*num_threads=*/sampler.config_.num_threads); + } + void operator()(GaussianMultivariateRegressionLeafModel& model) { + // TODO ... 
+ } + void operator()(CloglogOrdinalLeafModel& model) { + MCMCSampleOneIter( + *sampler.mean_forest_, *sampler.mean_forest_tracker_, *samples.mean_forests, model, + *sampler.forest_dataset_, *sampler.residual_, *sampler.tree_prior_mean_, sampler.rng_, + sampler.config_.var_weights_mean, sampler.config_.sweep_update_indices_mean, sampler.global_variance_, /*keep_forest=*/keep_sample, + /*pre_initialized=*/true, /*backfitting=*/false, + /*num_threads=*/sampler.config_.num_threads); + } + }; + + /*! Mean forest leaf scale update visitor */ + struct ScaleUpdateVisitor { + BARTSampler& sampler; + double leaf_scale; + void operator()(GaussianConstantLeafModel& model) { + model.SetScale(leaf_scale); + } + void operator()(GaussianUnivariateRegressionLeafModel& model) { + model.SetScale(leaf_scale); + } + void operator()(GaussianMultivariateRegressionLeafModel& model) { + // No-op for multivariate regression leaf model since scale is a vector + } + void operator()(CloglogOrdinalLeafModel& model) { + // No-op for cloglog ordinal leaf model since scale is not a variance parameter + } + }; /*! Internal reference to config and data state */ BARTConfig& config_; BARTData& data_; + /*! Leaf model for mean and variance forests */ + std::variant mean_leaf_model_; + LogLinearVarianceLeafModel variance_leaf_model_; + /*! Mean forest state */ std::unique_ptr mean_forest_; std::unique_ptr mean_forest_tracker_; std::unique_ptr tree_prior_mean_; bool has_mean_forest_ = false; + double init_val_mean_; + std::unique_ptr ordinal_sampler_; /*! Variance forest state */ std::unique_ptr variance_forest_; std::unique_ptr variance_forest_tracker_; std::unique_ptr tree_prior_variance_; bool has_variance_forest_ = false; + double init_val_variance_; /*! 
Dataset */ std::unique_ptr residual_; diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 7156eb06..4f6394ac 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -82,6 +82,11 @@ StochTree::BARTConfig convert_list_to_config(cpp11::list config) { output.b_sigma2_mean = get_config_scalar_default(config, "b_sigma2_mean", -1.0); output.sigma2_mean_init = get_config_scalar_default(config, "sigma2_mean_init", -1.0); output.sample_sigma2_leaf_mean = get_config_scalar_default(config, "sample_sigma2_leaf_mean", false); + output.mean_leaf_model_type = static_cast(get_config_scalar_default(config, "mean_leaf_model_type", 0)); + output.num_classes_cloglog = get_config_scalar_default(config, "num_classes_cloglog", 0); + output.cloglog_leaf_prior_shape = get_config_scalar_default(config, "cloglog_leaf_prior_shape", 2.0); + output.cloglog_leaf_prior_scale = get_config_scalar_default(config, "cloglog_leaf_prior_scale", 2.0); + output.cloglog_cutpoint_0 = get_config_scalar_default(config, "cloglog_cutpoint_0", 0.0); // Variance forest parameters output.num_trees_variance = get_config_scalar_default(config, "num_trees_variance", 0); @@ -174,6 +179,11 @@ cpp11::writable::list convert_bart_results_to_list(StochTree::BARTSamples& bart_ : R_NilValue; output.push_back(cpp11::named_arg("leaf_scale_samples") = leaf_scale_sexp); + SEXP cloglog_cutpoints_sexp = !bart_samples.cloglog_cutpoint_samples.empty() + ? 
static_cast(cpp11::writable::doubles(bart_samples.cloglog_cutpoint_samples.begin(), bart_samples.cloglog_cutpoint_samples.end())) + : R_NilValue; + output.push_back(cpp11::named_arg("cloglog_cutpoint_samples") = cloglog_cutpoints_sexp); + // Metadata about the model that was sampled double y_bar_sexp = bart_samples.y_bar; output.push_back(cpp11::named_arg("y_bar") = y_bar_sexp); diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 27ac3a52..66764757 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -12,7 +12,7 @@ namespace StochTree { -BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data) : config_{config}, data_{data} { +BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& data) : config_{config}, data_{data}, mean_leaf_model_(GaussianConstantLeafModel(0.0)), variance_leaf_model_(0.0, 0.0) { InitializeState(samples); } @@ -52,21 +52,34 @@ void BARTSampler::InitializeState(BARTSamples& samples) { double y_var = M2 / data_.n_train; // Standardization and calibration for mean forests - double init_val_mean; if (config_.num_trees_mean > 0) { + // Initialize leaf model + if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianConstant) { + mean_leaf_model_ = GaussianConstantLeafModel(config_.sigma2_mean_init); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianUnivariateRegression) { + mean_leaf_model_ = GaussianUnivariateRegressionLeafModel(config_.sigma2_mean_init); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianMultivariateRegression) { + // TODO + // mean_leaf_model_ = GaussianMultivariateRegressionLeafModel(...); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::CloglogOrdinal) { + mean_leaf_model_ = CloglogOrdinalLeafModel(config_.cloglog_leaf_prior_shape, config_.cloglog_leaf_prior_scale); + } else { + Log::Fatal("Unsupported leaf model type for mean forest"); + } + if (config_.link_function == LinkFunction::Probit) { 
samples.y_std = 1.0; samples.y_bar = norm_inv_cdf(y_mean); - init_val_mean = 0.0; + init_val_mean_ = 0.0; } else { if (config_.standardize_outcome) { samples.y_bar = y_mean; samples.y_std = std::sqrt(y_var); - init_val_mean = 0.0; + init_val_mean_ = 0.0; } else { samples.y_bar = 0.0; samples.y_std = 1.0; - init_val_mean = y_mean; + init_val_mean_ = y_mean; } } if (config_.sigma2_mean_init < 0.0) { @@ -88,7 +101,6 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } // Calibration for variance forests - double init_val_variance; if (config_.num_trees_variance > 0) { // NOTE: calibration only works for standardized outcomes if (config_.shape_variance_forest <= 0.0 || config_.scale_variance_forest <= 0.0) { @@ -103,9 +115,9 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } } if (config_.standardize_outcome) { - init_val_variance = 1.0; + init_val_variance_ = 1.0; } else { - init_val_variance = y_var; + init_val_variance_ = y_var; } } @@ -118,23 +130,17 @@ void BARTSampler::InitializeState(BARTSamples& samples) { // Initialize mean forest state (if present) if (config_.num_trees_mean > 0) { - mean_forest_ = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); - samples.mean_forests = std::make_unique(config_.num_trees_mean, config_.leaf_dim_mean, config_.leaf_constant_mean, config_.exponentiated_leaf_mean); - mean_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_mean, data_.n_train); - tree_prior_mean_ = std::make_unique(config_.alpha_mean, config_.beta_mean, config_.min_samples_leaf_mean, config_.max_depth_mean); - mean_forest_->SetLeafValue(init_val_mean / config_.num_trees_mean); - UpdateResidualEntireForest(*mean_forest_tracker_, *forest_dataset_, *residual_, mean_forest_.get(), !config_.leaf_constant_mean, std::minus()); - mean_forest_tracker_->UpdatePredictions(mean_forest_.get(), *forest_dataset_.get()); - 
has_mean_forest_ = true; + std::visit(MeanForestInitVisitor{*this, samples}, mean_leaf_model_); } // Initialize variance forest state (if present) if (config_.num_trees_variance > 0) { + variance_leaf_model_ = LogLinearVarianceLeafModel(config_.shape_variance_forest, config_.scale_variance_forest); variance_forest_ = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); samples.variance_forests = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_variance, data_.n_train); tree_prior_variance_ = std::make_unique(config_.alpha_variance, config_.beta_variance, config_.min_samples_leaf_variance, config_.max_depth_variance); - variance_forest_->SetLeafValue(init_val_variance / config_.num_trees_variance); + variance_forest_->SetLeafValue(init_val_variance_ / config_.num_trees_variance); variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); has_variance_forest_ = true; } @@ -154,6 +160,38 @@ void BARTSampler::InitializeState(BARTSamples& samples) { // RNG rng_ = std::mt19937(config_.random_seed >= 0 ? 
config_.random_seed : std::random_device{}()); + // Cloglog state + if (config_.link_function == LinkFunction::Cloglog) { + // Latent variable (Z in Alam et al (2025) notation) + forest_dataset_->AddAuxiliaryDimension(data_.n_train); + // Forest predictions (eta in Alam et al (2025) notation) + forest_dataset_->AddAuxiliaryDimension(data_.n_train); + // Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation) + forest_dataset_->AddAuxiliaryDimension(config_.num_classes_cloglog - 1); + // Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation) + // This auxiliary series is designed so that the element stored at position `i` + // corresponds to the sum of all exponentiated gamma_j values for j < i. + // It has cloglog_num_categories elements instead of cloglog_num_categories - 1 because + // even the largest categorical index has a valid value of sum_{j < i} exp(gamma_j) + forest_dataset_->AddAuxiliaryDimension(config_.num_classes_cloglog); + + // Set initial values for auxiliary data + // Initialize latent variables to zero (slot 0) + for (int i = 0; i < data_.n_train; i++) { + forest_dataset_->SetAuxiliaryDataValue(0, i, 0.0); + } + // Initialize forest predictions to zero (slot 1) + for (int i = 0; i < data_.n_train; i++) { + forest_dataset_->SetAuxiliaryDataValue(1, i, 0.0); + } + // Initialize log-scale cutpoints to 0 + for (int i = 0; i < config_.num_classes_cloglog - 1; i++) { + forest_dataset_->SetAuxiliaryDataValue(2, i, 0.0); + } + // Convert to cumulative exponentiated cutpoints directly in C++ + ordinal_sampler_->UpdateCumulativeExpSums(*forest_dataset_); + } + // Other internal model state global_variance_ = config_.sigma2_global_init; leaf_scale_ = config_.sigma2_mean_init; @@ -174,10 +212,8 @@ void BARTSampler::run_gfr(BARTSamples& samples, int num_gfr, bool keep_gfr) { } // TODO: dispatch correct leaf model and variance model based on config; currently hardcoded to Gaussian constant-leaf and homoskedastic variance - 
std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); - std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); for (int i = 0; i < num_gfr; i++) { - RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/true, /*keep_sample=*/keep_gfr); + RunOneIteration(samples, /*gfr=*/true, /*keep_sample=*/keep_gfr); } } @@ -197,15 +233,13 @@ void BARTSampler::run_mcmc(BARTSamples& samples, int num_burnin, int keep_every, } // Create leaf models and pass them to the RunOneIteration function; these are updated in place and will reflect the current state of the leaf scale parameters (if they are being sampled) - std::unique_ptr mean_leaf_model_ptr = std::make_unique(leaf_scale_); - std::unique_ptr variance_leaf_model_ptr = std::make_unique(config_.shape_variance_forest, config_.scale_variance_forest); bool keep_forest = false; for (int i = 0; i < num_burnin + keep_every * num_mcmc; i++) { if (i >= num_burnin && (i - num_burnin) % keep_every == 0) keep_forest = true; else keep_forest = false; - RunOneIteration(samples, mean_leaf_model_ptr.get(), variance_leaf_model_ptr.get(), /*gfr=*/false, /*keep_sample=*/keep_forest); + RunOneIteration(samples, /*gfr=*/false, /*keep_sample=*/keep_forest); } } @@ -225,30 +259,19 @@ void BARTSampler::postprocess_samples(BARTSamples& samples) { } } -void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafModel* mean_leaf_model, LogLinearVarianceLeafModel* variance_leaf_model, bool gfr, bool keep_sample) { +void BARTSampler::RunOneIteration(BARTSamples& samples, bool gfr, bool keep_sample) { if (has_mean_forest_) { if (gfr) { - GFRSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config_.var_weights_mean, config_.sweep_update_indices_mean, global_variance_, config_.feature_types, - config_.cutpoint_grid_size, 
/*keep_forest=*/keep_sample, - /*pre_initialized=*/true, /*backfitting=*/true, - /*num_features_subsample=*/config_.num_features_subsample_mean, config_.num_threads); + std::visit(GFROneIterationVisitor{*this, samples, keep_sample}, mean_leaf_model_); } else { - MCMCSampleOneIter( - *mean_forest_, *mean_forest_tracker_, *samples.mean_forests, *mean_leaf_model, - *forest_dataset_, *residual_, *tree_prior_mean_, rng_, - config_.var_weights_mean, config_.sweep_update_indices_mean, global_variance_, /*keep_forest=*/keep_sample, - /*pre_initialized=*/true, /*backfitting=*/true, - /*num_threads=*/config_.num_threads); + std::visit(MCMCOneIterationVisitor{*this, samples, keep_sample}, mean_leaf_model_); } } if (has_variance_forest_) { if (gfr) { GFRSampleOneIter( - *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, + *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, variance_leaf_model_, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, config_.var_weights_variance, config_.sweep_update_indices_variance, global_variance_, config_.feature_types, config_.cutpoint_grid_size, /*keep_forest=*/keep_sample, @@ -256,7 +279,7 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode /*num_features_subsample=*/config_.num_features_subsample_variance, config_.num_threads); } else { MCMCSampleOneIter( - *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, *variance_leaf_model, + *variance_forest_, *variance_forest_tracker_, *samples.variance_forests, variance_leaf_model_, *forest_dataset_, *residual_, *tree_prior_variance_, rng_, config_.var_weights_variance, config_.sweep_update_indices_variance, global_variance_, /*keep_forest=*/keep_sample, /*pre_initialized=*/true, /*backfitting=*/false, @@ -269,6 +292,10 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode residual_->GetData().data(), samples.y_bar, data_.n_train); } + if 
(config_.link_function == LinkFunction::Cloglog) { + // TODO + } + if (sample_sigma2_global_) { global_variance_ = var_model_->SampleVarianceParameter( residual_->GetData(), config_.a_sigma2_global, config_.b_sigma2_global, rng_); @@ -277,7 +304,24 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode if (sample_sigma2_leaf_) { leaf_scale_ = leaf_scale_model_->SampleVarianceParameter( mean_forest_.get(), config_.a_sigma2_mean, config_.b_sigma2_mean, rng_); - mean_leaf_model->SetScale(leaf_scale_); + std::visit(ScaleUpdateVisitor{*this, leaf_scale_}, mean_leaf_model_); + } + + // Gibbs updates for the cloglog model + if (config_.link_function == LinkFunction::Cloglog) { + // Update auxiliary data to current forest predictions + for (int i = 0; i < data_.n_train; i++) { + forest_dataset_->SetAuxiliaryDataValue(1, i, mean_forest_tracker_->GetSamplePrediction(i)); + } + + // Sample latent z_i's using truncated exponential + ordinal_sampler_->UpdateLatentVariables(*forest_dataset_, residual_->GetData(), rng_); + + // Sample gamma parameters (cutpoints) + ordinal_sampler_->UpdateGammaParams(*forest_dataset_, residual_->GetData(), config_.cloglog_leaf_prior_shape, config_.cloglog_leaf_prior_scale, config_.cloglog_cutpoint_0, rng_); + + // Update cumulative sum of exp(gamma) values + ordinal_sampler_->UpdateCumulativeExpSums(*forest_dataset_); } if (keep_sample) { @@ -287,7 +331,23 @@ void BARTSampler::RunOneIteration(BARTSamples& samples, GaussianConstantLeafMode if (has_mean_forest_) { double* mean_forest_preds_train = mean_forest_tracker_->GetSumPredictions(); samples.mean_forest_predictions_train.insert(samples.mean_forest_predictions_train.end(), - mean_forest_preds_train, mean_forest_preds_train + samples.num_train); + mean_forest_preds_train, + mean_forest_preds_train + samples.num_train); + } + if (has_variance_forest_) { + double* variance_forest_preds_train = variance_forest_tracker_->GetSumPredictions(); + 
samples.variance_forest_predictions_train.insert(samples.variance_forest_predictions_train.end(), + variance_forest_preds_train, + variance_forest_preds_train + samples.num_train); + } + if (config_.link_function == LinkFunction::Cloglog) { + // Store cutpoint samples + std::vector cloglog_cutpoints(config_.num_classes_cloglog - 1); + for (int i = 0; i < config_.num_classes_cloglog - 1; i++) { + cloglog_cutpoints[i] = forest_dataset_->GetAuxiliaryDataValue(2, i); + } + samples.cloglog_cutpoint_samples.insert(samples.cloglog_cutpoint_samples.end(), + cloglog_cutpoints.data(), cloglog_cutpoints.data() + cloglog_cutpoints.size()); } } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 72995b34..e8cec5d2 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -2198,6 +2198,11 @@ inline StochTree::BARTConfig convert_dict_to_bart_config(py::dict config_dict) { output.b_sigma2_mean = get_config_scalar_default(config_dict, "b_sigma2_mean", -1.0); output.sigma2_mean_init = get_config_scalar_default(config_dict, "sigma2_mean_init", -1.0); output.sample_sigma2_leaf_mean = get_config_scalar_default(config_dict, "sample_sigma2_leaf_mean", false); + output.mean_leaf_model_type = static_cast(get_config_scalar_default(config_dict, "mean_leaf_model_type", 0)); + output.num_classes_cloglog = get_config_scalar_default(config_dict, "num_classes_cloglog", 0); + output.cloglog_leaf_prior_shape = get_config_scalar_default(config_dict, "cloglog_leaf_prior_shape", 2.0); + output.cloglog_leaf_prior_scale = get_config_scalar_default(config_dict, "cloglog_leaf_prior_scale", 2.0); + output.cloglog_cutpoint_0 = get_config_scalar_default(config_dict, "cloglog_cutpoint_0", 0.0); // Variance forest parameters output.num_trees_variance = get_config_scalar_default(config_dict, "num_trees_variance", 0); @@ -2387,6 +2392,16 @@ inline py::dict convert_bart_results_to_dict( output["leaf_scale_samples"] = array; } + // Cloglog cutpoint samples + if 
(results_raw.cloglog_cutpoint_samples.empty()) { + output["cloglog_cutpoint_samples"] = py::none(); + } else { + auto input_vec = results_raw.cloglog_cutpoint_samples; + py::array_t array(input_vec.size()); + std::copy(input_vec.begin(), input_vec.end(), array.mutable_data()); + output["cloglog_cutpoint_samples"] = array; + } + // Unpack scalars output["y_bar"] = results_raw.y_bar; output["y_std"] = results_raw.y_std; diff --git a/stochtree/bart.py b/stochtree/bart.py index 67eee8c5..4f4e792a 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -214,6 +214,8 @@ def sample( - **keep_vars** (*list* or *np.array*): Variable names or column indices to include in the mean forest. Defaults to ``None``. - **drop_vars** (*list* or *np.array*): Variable names or column indices to exclude from the mean forest. Defaults to ``None``. Ignored if ``keep_vars`` is also set. - **num_features_subsample** (*int*): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + - **cloglog_leaf_prior_shape** (*float*): Shape parameter for the prior on leaf parameters in a cloglog ordinal leaf model. Defaults to ``2.0``. + - **cloglog_leaf_prior_scale** (*float*): Scale parameter for the prior on leaf parameters in a cloglog ordinal leaf model. Defaults to ``2.0``. **variance_forest_params keys** @@ -276,6 +278,8 @@ def sample( "keep_vars": None, "drop_vars": None, "num_features_subsample": None, + "cloglog_leaf_prior_shape": 2.0, + "cloglog_leaf_prior_scale": 2.0, } mean_forest_params_updated = _preprocess_params( mean_forest_params_default, mean_forest_params @@ -347,6 +351,8 @@ def sample( num_features_subsample_mean = mean_forest_params_updated[ "num_features_subsample" ] + cloglog_leaf_prior_shape = mean_forest_params_updated["cloglog_leaf_prior_shape"] + cloglog_leaf_prior_scale = mean_forest_params_updated["cloglog_leaf_prior_scale"] # 3. 
Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -1096,6 +1102,31 @@ def sample( elif self.rfx_model_spec == "intercept_only": if rfx_basis_test is None: rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest = 3 + leaf_dimension_variance = 1 + + # Determine the mean forest leaf model type + if link_is_cloglog and not self.has_basis: + leaf_model_mean_forest = 4 + leaf_dimension_mean = 1 + elif not self.has_basis: + leaf_model_mean_forest = 0 + leaf_dimension_mean = 1 + elif self.num_basis == 1: + leaf_model_mean_forest = 1 + leaf_dimension_mean = 1 + else: + leaf_model_mean_forest = 2 + leaf_dimension_mean = self.num_basis + + # Determine cloglog number of classes + if link_is_cloglog: + unique_outcomes = np.sort(np.unique(y_train)) + cloglog_num_categories = int(np.max(y_train - np.min(unique_outcomes))) + 1 + else: + cloglog_num_categories = 0 if run_cpp: # Arrange all config in a large python dictionary @@ -1125,6 +1156,11 @@ def sample( "b_sigma2_mean": b_leaf, "sigma2_mean_init": -1.0, "sample_sigma2_leaf_mean": sample_sigma2_leaf, + "mean_leaf_model_type": leaf_model_mean_forest, + "num_classes_cloglog": cloglog_num_categories, + "cloglog_leaf_prior_shape": cloglog_leaf_prior_shape, + "cloglog_leaf_prior_scale": cloglog_leaf_prior_scale, + "cloglog_cutpoint_0": 0, "num_trees_variance": num_trees_variance, "leaf_prior_calibration_param": a_0, "shape_variance_forest": a_forest, @@ -1198,7 +1234,7 @@ def sample( self.y_bar = bart_results["y_bar"] self.y_std = bart_results["y_std"] self.sigma2_init = bart_results["sigma2_init"] - self.sigma2_leaf_init = bart_results["sigma2_leaf_init"] if self.include_mean_forest else None + self.sigma2_leaf_init = bart_results["sigma2_mean_init"] if self.include_mean_forest else None self.b_leaf = bart_results["b_sigma2_mean"] if self.include_mean_forest else None 
self.shape_variance_forest = bart_results["shape_variance_forest"] if self.include_variance_forest else None self.scale_variance_forest = bart_results["scale_variance_forest"] if self.include_variance_forest else None @@ -1233,6 +1269,9 @@ def sample( self.global_var_samples = bart_results["global_var_samples"] * self.y_std * self.y_std if self.sample_sigma2_leaf: self.leaf_scale_samples = bart_results["leaf_scale_samples"] + if link_is_cloglog: + self.cloglog_num_categories = cloglog_num_categories + self.cloglog_cutpoint_samples = bart_results["cloglog_cutpoint_samples"].reshape(cloglog_num_categories - 1, bart_results["num_samples"], order="F") # Unpack other model metadata self.num_samples = bart_results["num_samples"] @@ -1339,8 +1378,10 @@ def sample( cloglog_cutpoint_0 = 0.0 # Set shape and rate parameters for conditional gamma model - cloglog_forest_shape = 2.0 - cloglog_forest_rate = 2.0 + if cloglog_leaf_prior_shape is None: + cloglog_forest_shape = 2.0 + if cloglog_leaf_prior_scale is None: + cloglog_forest_rate = 2.0 else: # Standardize if requested if self.standardize: @@ -1440,7 +1481,6 @@ def sample( b_forest = 1.0 self.shape_variance_forest = a_forest self.scale_variance_forest = b_forest - self.sigma2_leaf_init = bart_results["sigma2_leaf_init"] if self.include_mean_forest else None # Set up random effects structures if self.has_rfx: @@ -1553,24 +1593,6 @@ def sample( cpp_rng = RNG(random_seed) self.rng = np.random.default_rng(random_seed) - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest = 3 - leaf_dimension_variance = 1 - - # Determine the mean forest leaf model type - if link_is_cloglog and not self.has_basis: - leaf_model_mean_forest = 4 - leaf_dimension_mean = 1 - elif not self.has_basis: - leaf_model_mean_forest = 0 - leaf_dimension_mean = 1 - elif self.num_basis == 1: - leaf_model_mean_forest = 1 - leaf_dimension_mean = 1 - else: - leaf_model_mean_forest = 2 - leaf_dimension_mean = self.num_basis - # 
Sampling data structures global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) if self.include_mean_forest: From 62a528ae7b17fa5b0ca2231da6f27da56a4a61d8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Apr 2026 17:39:51 -0500 Subject: [PATCH 56/64] Working cloglog implementation (with two R-path bug fixes) --- R/bart.R | 34 ++- ...hmark_cpp_vs_py_sampler_ordinal_cloglog.py | 173 +++++++++++++++ ...nchmark_cpp_vs_r_sampler_ordinal_cloglog.R | 209 ++++++++++++++++++ include/stochtree/bart_sampler.h | 9 +- src/bart_sampler.cpp | 51 ++++- src/forest.cpp | 2 +- 6 files changed, 460 insertions(+), 18 deletions(-) create mode 100644 debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py create mode 100644 debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R diff --git a/R/bart.R b/R/bart.R index e557e557..11bdfb01 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1199,7 +1199,11 @@ bart <- function( bart_results <- bart_sample_cpp( X_train = X_train, - y_train = y_train, + y_train = if (link_is_cloglog) { + as.numeric(y_train - min(y_train)) + } else { + y_train + }, X_test = if (exists("X_test")) X_test else NULL, n_train = nrow(X_train), n_test = if (!is.null(X_test)) nrow(X_test) else 0L, @@ -1353,9 +1357,9 @@ bart <- function( cloglog_num_categories - 1, bart_results[["num_samples"]] ) - result[["cloglog_cutpoint_samples"]] <- t(bart_results[[ + result[["cloglog_cutpoint_samples"]] <- bart_results[[ "cloglog_cutpoint_samples" - ]]) + ]] } class(result) <- "bartmodel" @@ -3113,8 +3117,12 @@ predict.bartmodel <- function( mean_forest_probabilities[, j, ] <- (1 - exp( -exp( - mean_forest_predictions + - cloglog_cutpoint_samples[j, ] + sweep( + mean_forest_predictions, + 2, + cloglog_cutpoint_samples[j, ], + "+" + ) ) )) } else if (j == cloglog_num_categories) { @@ -3127,15 +3135,23 @@ predict.bartmodel <- function( } else { mean_forest_probabilities[, j, ] <- (exp( -exp( - mean_forest_predictions + - cloglog_cutpoint_samples[j - 1, ] + sweep( + 
mean_forest_predictions, + 2, + cloglog_cutpoint_samples[j - 1, ], + "+" + ) ) ) * (1 - exp( -exp( - mean_forest_predictions + - cloglog_cutpoint_samples[j, ] + sweep( + mean_forest_predictions, + 2, + cloglog_cutpoint_samples[j, ], + "+" + ) ) ))) } diff --git a/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py new file mode 100644 index 00000000..34a05885 --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py @@ -0,0 +1,173 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop – ordinal cloglog BART. + +Compares runtime, mean Brier score, and mean RMSE-to-truth (vs. true class +probabilities) across run_cpp=True / False in BARTModel.sample(). + +DGP uses 4 ordinal categories with a cloglog link. +The latent step function f(X) is on the log-log scale, and each category +boundary (gamma_k) is fixed at log(k) for k = 1, 2, 3 so the four +cumulative probabilities are P(Y <= k | X) = 1 - exp(-exp(f(X) - gamma_k)). 
+ +Usage: + conda activate stochtree-book # or: source venv/bin/activate + python debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py +""" + +import time +import numpy as np +from stochtree import BARTModel, OutcomeModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 2000 +p = 10 +X = rng.uniform(size=(n, p)) + +# Latent step function on the cloglog scale +f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -2.0, 0.0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -0.5, 0.0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 0.5, 0.0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 1.0, 0.0) +) + +# Fixed log-scale cutpoints (gamma_k); K = 4 categories => K-1 = 3 cutpoints +# gamma_0 is fixed at 0 for identifiability; gamma_1 = log(2), gamma_2 = log(3) +K = 4 +gamma_true = np.array([0.0, np.log(2), np.log(3)]) + +# True cumulative probabilities: P(Y <= k | X) = 1 - exp(-exp(f_X - gamma_k)) +# Shape: (n, K-1) +cum_prob = 1.0 - np.exp(-np.exp(f_X[:, None] - gamma_true[None, :])) + +# True class probabilities: P(Y = k | X), shape (n, K) +p_X = np.column_stack([ + cum_prob[:, 0], + cum_prob[:, 1] - cum_prob[:, 0], + cum_prob[:, 2] - cum_prob[:, 1], + 1.0 - cum_prob[:, 2], +]) + +# Draw ordinal outcomes (0-indexed: 0, 1, 2, 3) +u = rng.uniform(size=n) +y = ( + (u > cum_prob[:, 0]).astype(int) + + (u > cum_prob[:, 1]).astype(int) + + (u > cum_prob[:, 2]).astype(int) +).astype(float) + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +p_test = p_X[test_inds] # (n_test, K) true class probabilities + +# 
--------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr = 10 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 + +print( + f"K={K} n_train={n_train} n_test={n_test} p={p} " + f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=0, + num_mcmc=num_mcmc, + mean_forest_params={"num_trees": num_trees}, + general_params={ + "random_seed": seed, + "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog"), + "sample_sigma2_global": False, + }, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + # predict() returns a dict; for ordinal probability scale the value for + # "y_hat" has shape (n_test, K, num_mcmc) + preds = m.predict(X=X_test, scale="probability") + p_hat = preds["y_hat"].mean(axis=2) # (n_test, K) posterior mean + + # Mean Brier score across all cells + brier = float(np.mean((p_hat - p_test) ** 2)) + # Per-class RMSE vs. 
true probs, averaged over classes + rmse_p = float(np.mean(np.sqrt(np.mean((p_hat - p_test) ** 2, axis=0)))) + + return {"elapsed": elapsed, "brier": brier, "rmse_p": rmse_p} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results: list) -> dict: + elapsed = [r["elapsed"] for r in results] + brier = [r["brier"] for r in results] + rmse_p = [r["rmse_p"] for r in results] + return { + "elapsed_mean": float(np.mean(elapsed)), + "elapsed_sd": float(np.std(elapsed, ddof=1)), + "brier_mean": float(np.mean(brier)), + "rmse_p_mean": float(np.mean(rmse_p)), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print( + f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} " + f"{'Brier':>12} {'RMSE (vs truth)':>15}" +) +print("-" * 75) +for label, s in rows: + print( + f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['brier_mean']:>12.4f} {s['rmse_p_mean']:>15.4f}" + ) + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"Brier delta (cpp - py): {s_cpp['brier_mean'] - 
s_py['brier_mean']:.4f}\n" + f"RMSE-p delta (cpp - py): {s_cpp['rmse_p_mean'] - s_py['rmse_p_mean']:.4f}" +) diff --git a/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R new file mode 100644 index 00000000..10938387 --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R @@ -0,0 +1,209 @@ +## Benchmark: C++ sampler loop vs. R sampler loop – ordinal cloglog BART +## Compares runtime, mean Brier score, and mean RMSE-to-truth (vs. true class +## probabilities) across run_cpp = TRUE / FALSE in bart(). +## +## DGP uses 4 ordinal categories with a cloglog link. +## The latent step function f(X) is on the log-log scale, and each category +## boundary (gamma_k) is fixed at log(k) for k = 1, 2, 3 so the four +## cumulative probabilities are P(Y <= k | X) = 1 - exp(-exp(f(X) - gamma_k)). +## +## Usage: Rscript debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R +## or source() from an interactive session after devtools::load_all('.') +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) + +# Latent step function on the cloglog scale +f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-2.0) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-0.5) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (0.5) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (1.0)) + +# Fixed log-scale cutpoints (gamma_k); K = 4 categories => K-1 = 3 cutpoints +# gamma_0 is fixed at 0 for identifiability; gamma_1 = log(2), gamma_2 = log(3) +K <- 4 +gamma_true <- c(0, log(2), log(3)) + +# True cumulative probabilities: P(Y <= k | X) = 1 - exp(-exp(f_X - gamma_k)) +# True class probabilities: P(Y = k | X) = P(Y <= k) - P(Y <= k-1) +cum_prob <- outer(f_X, gamma_true, function(f, g) 1 - exp(-exp(f - g))) +p_X <- cbind( + cum_prob[, 
1], + cum_prob[, 2] - cum_prob[, 1], + cum_prob[, 3] - cum_prob[, 2], + 1 - cum_prob[, 3] +) # n x K matrix of true class probs + +# Draw ordinal outcomes (1-indexed: 1, 2, 3, 4) +u <- runif(n) +y <- as.integer(u > cum_prob[, 1]) + + as.integer(u > cum_prob[, 2]) + + as.integer(u > cum_prob[, 3]) + + 1L + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +p_test <- p_X[test_inds, ] # n_test x K matrix of true class probabilities + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_mcmc <- 100 +num_trees <- 200 +n_reps <- 3 + +cat(sprintf( + "K=%d n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + K, + n_train, + n_test, + p, + num_trees, + num_gfr, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = num_mcmc, + mean_forest_params = list(num_trees = num_trees), + general_params = list( + random_seed = seed, + outcome_model = OutcomeModel(outcome = "ordinal", link = "cloglog"), + sample_sigma2_global = FALSE + ), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - t0)[["elapsed"]] + + # Posterior-mean predicted class probabilities on the test set + # predict() returns an n_test x K x num_mcmc array for ordinal outcomes + p_hat_arr <- predict( + m, + X = X_test, + 
type = "posterior", + terms = "y_hat", + scale = "probability" + ) + p_hat <- apply(p_hat_arr, c(1, 2), mean) # n_test x K posterior mean + + # Mean Brier score across classes (multi-class generalisation) + brier <- mean((p_hat - p_test)^2) + + # Per-class RMSE vs. true probabilities, then averaged + rmse_p <- mean(sqrt(colMeans((p_hat - p_test)^2))) + + list(elapsed = elapsed, brier = brier, rmse_p = rmse_p) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once( + run_cpp = TRUE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once( + run_cpp = FALSE, + num_gfr = num_gfr, + num_mcmc = num_mcmc, + seed = seeds[i] + ) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + brier <- sapply(results, `[[`, "brier") + rmse_p <- sapply(results, `[[`, "rmse_p") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = sd(elapsed), + brier_mean = mean(brier), + rmse_p_mean = mean(rmse_p), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %15s\n", + "Sampler", + "Time (s)", + "SD", + "Brier", + "RMSE (vs truth)" +)) +cat(strrep("-", 75), 
"\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12.4f %15.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + res$brier_mean[i], + res$rmse_p_mean[i] + )) +} + +speedup <- res$elapsed_mean[res$sampler == "R (run_cpp=FALSE)"] / + res$elapsed_mean[res$sampler == "cpp (run_cpp=TRUE)"] +cat(sprintf( + "\nSpeedup (R / cpp): %.2fx\n", + speedup +)) +cat(sprintf( + "Brier delta (cpp - R): %.4f\nRMSE-p delta (cpp - R): %.4f\n", + res$brier_mean[1] - res$brier_mean[2], + res$rmse_p_mean[1] - res$rmse_p_mean[2] +)) diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 975844ed..771e81c2 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -65,7 +65,14 @@ class BARTSampler { // TODO ... } void operator()(CloglogOrdinalLeafModel& model) { - // TODO ... + sampler.mean_forest_ = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + sampler.mean_forest_tracker_ = std::make_unique(sampler.forest_dataset_->GetCovariates(), sampler.config_.feature_types, sampler.config_.num_trees_mean, sampler.data_.n_train); + sampler.tree_prior_mean_ = std::make_unique(sampler.config_.alpha_mean, sampler.config_.beta_mean, sampler.config_.min_samples_leaf_mean, sampler.config_.max_depth_mean); + sampler.mean_forest_->SetLeafValue(sampler.init_val_mean_ / sampler.config_.num_trees_mean); + UpdateResidualEntireForest(*sampler.mean_forest_tracker_, *sampler.forest_dataset_, *sampler.residual_, sampler.mean_forest_.get(), false, std::minus()); + sampler.mean_forest_tracker_->UpdatePredictions(sampler.mean_forest_.get(), *sampler.forest_dataset_.get()); + sampler.has_mean_forest_ = true; } }; diff --git 
a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 66764757..e93ff545 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -17,6 +18,24 @@ BARTSampler::BARTSampler(BARTSamples& samples, BARTConfig& config, BARTData& dat } void BARTSampler::InitializeState(BARTSamples& samples) { + // Validate y_train values match the expected support for discrete link functions + if (config_.link_function == LinkFunction::Probit) { + for (int i = 0; i < data_.n_train; i++) { + if (data_.y_train[i] != 0.0 && data_.y_train[i] != 1.0) { + Log::Fatal("Outcomes must be 0 or 1 for probit link function"); + } + } + } else if (config_.link_function == LinkFunction::Cloglog) { + for (int i = 0; i < data_.n_train; i++) { + if (std::floor(data_.y_train[i]) != data_.y_train[i]) { + Log::Fatal("Outcomes must be integers for cloglog link function"); + } + if (data_.y_train[i] < 0.0) { + Log::Fatal("Outcomes must be 0-indexed for cloglog link function; remap before calling the sampler"); + } + } + } + // Load data from BARTData object into ForestDataset object forest_dataset_ = std::make_unique(); forest_dataset_->AddCovariates(data_.X_train, data_.n_train, data_.p, /*row_major=*/false); @@ -68,18 +87,34 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } if (config_.link_function == LinkFunction::Probit) { + // Initialize forests to 0, no scaling, but offset by the probit transform of the mean outcome to improve mixing samples.y_std = 1.0; samples.y_bar = norm_inv_cdf(y_mean); init_val_mean_ = 0.0; + } else if (config_.link_function == LinkFunction::Cloglog) { + // Initialize forests to 0, no scaling or location shifting of the outcome + // Outcomes are expected to already be 0-indexed by the caller + samples.y_std = 1.0; + samples.y_bar = 0.0; + init_val_mean_ = 0.0; } else { - if (config_.standardize_outcome) { - samples.y_bar = y_mean; - samples.y_std = std::sqrt(y_var); - init_val_mean_ = 0.0; + if 
(config_.mean_leaf_model_type == MeanLeafModelType::GaussianConstant) { + // Case 1: Constant leaf + if (config_.standardize_outcome) { + samples.y_bar = y_mean; + samples.y_std = std::sqrt(y_var); + init_val_mean_ = 0.0; + } else { + samples.y_bar = 0.0; + samples.y_std = 1.0; + init_val_mean_ = y_mean; + } + } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianUnivariateRegression) { + // Case 2: Univariate leaf regression + // TODO ... } else { - samples.y_bar = 0.0; - samples.y_std = 1.0; - init_val_mean_ = y_mean; + // Case 3: Multivariate leaf regression + // TODO ... } } if (config_.sigma2_mean_init < 0.0) { @@ -162,6 +197,8 @@ void BARTSampler::InitializeState(BARTSamples& samples) { // Cloglog state if (config_.link_function == LinkFunction::Cloglog) { + // Initialize the ordinal sampler + ordinal_sampler_ = std::make_unique(); // Latent variable (Z in Alam et al (2025) notation) forest_dataset_->AddAuxiliaryDimension(data_.n_train); // Forest predictions (eta in Alam et al (2025) notation) diff --git a/src/forest.cpp b/src/forest.cpp index 357777e2..ce258821 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -862,7 +862,7 @@ void initialize_forest_model_active_forest_cpp(cpp11::external_pointerNumObservations(); std::vector initial_preds(n, init_val); data->AddVarianceWeights(initial_preds.data(), n); - } else if (model_type == StochTree::ModelType::kLogLinearVariance) { + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { leaf_init_val = init_val / static_cast(num_trees); active_forest->SetLeafValue(leaf_init_val); UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus()); From 349de20e76c2c683de415e644233c6d49e161929 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Apr 2026 18:02:01 -0500 Subject: [PATCH 57/64] Fixed python cloglog test failures --- stochtree/bart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stochtree/bart.py b/stochtree/bart.py index 
4f4e792a..b8a81a4d 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1380,8 +1380,12 @@ def sample( # Set shape and rate parameters for conditional gamma model if cloglog_leaf_prior_shape is None: cloglog_forest_shape = 2.0 + else: + cloglog_forest_shape = cloglog_leaf_prior_shape if cloglog_leaf_prior_scale is None: cloglog_forest_rate = 2.0 + else: + cloglog_forest_rate = cloglog_leaf_prior_scale else: # Standardize if requested if self.standardize: From 481cb4a6e70d04b184b9eb40236aa541618daf2d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Apr 2026 18:31:05 -0500 Subject: [PATCH 58/64] Fix bug in python cloglog warm-start sampling --- src/bart_sampler.cpp | 1 - stochtree/bart.py | 19 +++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index e93ff545..bf00045a 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include diff --git a/stochtree/bart.py b/stochtree/bart.py index b8a81a4d..3690505a 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1122,11 +1122,7 @@ def sample( leaf_dimension_mean = self.num_basis # Determine cloglog number of classes - if link_is_cloglog: - unique_outcomes = np.sort(np.unique(y_train)) - cloglog_num_categories = int(np.max(y_train - np.min(unique_outcomes))) + 1 - else: - cloglog_num_categories = 0 + cloglog_num_categories = int(np.max(y_train - np.min(y_train))) + 1 if link_is_cloglog else 0 if run_cpp: # Arrange all config in a large python dictionary @@ -1190,7 +1186,8 @@ def sample( # Passing already-F-contiguous arrays causes pybind11 to return a view of # the original, which remains alive in this Python scope. 
X_train_cpp = np.asfortranarray(X_train_processed) - y_train_cpp = np.asfortranarray(y_train) + y_train_remapped = y_train - np.min(y_train) if link_is_cloglog else y_train + y_train_cpp = np.asfortranarray(y_train_remapped) X_test_cpp = np.asfortranarray(X_test_processed) if self.has_test else None basis_train_cpp = np.asfortranarray(leaf_basis_train) if self.has_basis else None basis_test_cpp = np.asfortranarray(leaf_basis_test) if self.has_basis and self.has_test else None @@ -1879,6 +1876,10 @@ def sample( residual_train, True, ) + if link_is_cloglog: + # ReconstituteFromForest corrupts the residual for cloglog + # (computes y - forest_preds instead of keeping category labels) + residual_train.update_data(resid_train[:, 0]) # Reset leaf scale if sample_sigma2_leaf: leaf_scale_double = self.leaf_scale_samples[ @@ -1934,6 +1935,9 @@ def sample( residual_train, True, ) + if link_is_cloglog: + # ReconstituteFromForest corrupts the residual for cloglog + residual_train.update_data(resid_train[:, 0]) # Reset leaf scale if sample_sigma2_leaf and previous_leaf_var_samples is not None: leaf_scale_double = previous_leaf_var_samples[ @@ -1998,6 +2002,9 @@ def sample( residual_train, True, ) + if link_is_cloglog: + # ReconstituteFromForest corrupts the residual for cloglog + residual_train.update_data(resid_train[:, 0]) # Reset mean forest leaf scale if sample_sigma2_leaf and previous_leaf_var_samples is not None: current_leaf_scale[0, 0] = sigma2_leaf From c9fbceeab4d692c10bb6915361cb724c186d5bc3 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Apr 2026 18:40:32 -0500 Subject: [PATCH 59/64] Added unit tests for warm start cloglog in python and updated cloglog debug scripts --- debug/benchmark_cpp_vs_py_sampler_cloglog.py | 2 +- ...chmark_cpp_vs_py_sampler_ordinal_cloglog.py | 2 +- debug/benchmark_cpp_vs_r_sampler_cloglog.R | 5 ++++- ...enchmark_cpp_vs_r_sampler_ordinal_cloglog.R | 5 ++++- test/python/test_bart.py | 18 ++++++++++++++++++ 5 files changed, 28 
insertions(+), 4 deletions(-) diff --git a/debug/benchmark_cpp_vs_py_sampler_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_cloglog.py index 67e63f09..9d6eacb4 100644 --- a/debug/benchmark_cpp_vs_py_sampler_cloglog.py +++ b/debug/benchmark_cpp_vs_py_sampler_cloglog.py @@ -72,7 +72,7 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: num_gfr=num_gfr, num_burnin=0, num_mcmc=num_mcmc, - mean_forest_params={"num_trees": num_trees}, + mean_forest_params={"num_trees": num_trees, "sample_sigma2_leaf": False}, general_params={ "random_seed": seed, "outcome_model": OutcomeModel(outcome="binary", link="cloglog"), diff --git a/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py index 34a05885..e6c568ed 100644 --- a/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py +++ b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py @@ -95,7 +95,7 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: num_gfr=num_gfr, num_burnin=0, num_mcmc=num_mcmc, - mean_forest_params={"num_trees": num_trees}, + mean_forest_params={"num_trees": num_trees, "sample_sigma2_leaf": False}, general_params={ "random_seed": seed, "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog"), diff --git a/debug/benchmark_cpp_vs_r_sampler_cloglog.R b/debug/benchmark_cpp_vs_r_sampler_cloglog.R index 47613dbb..9da63934 100644 --- a/debug/benchmark_cpp_vs_r_sampler_cloglog.R +++ b/debug/benchmark_cpp_vs_r_sampler_cloglog.R @@ -72,7 +72,10 @@ run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { num_gfr = num_gfr, num_burnin = 0, num_mcmc = num_mcmc, - mean_forest_params = list(num_trees = num_trees), + mean_forest_params = list( + num_trees = num_trees, + sample_sigma2_leaf = FALSE + ), general_params = list( random_seed = seed, outcome_model = OutcomeModel(outcome = "binary", link = "cloglog"), diff --git a/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R 
b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R index 10938387..ad8236be 100644 --- a/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R +++ b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R @@ -93,7 +93,10 @@ run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { num_gfr = num_gfr, num_burnin = 0, num_mcmc = num_mcmc, - mean_forest_params = list(num_trees = num_trees), + mean_forest_params = list( + num_trees = num_trees, + sample_sigma2_leaf = FALSE + ), general_params = list( random_seed = seed, outcome_model = OutcomeModel(outcome = "ordinal", link = "cloglog"), diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 2e832ba0..d086c58e 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -1411,6 +1411,15 @@ def test_cloglog_binary_bart_with_gfr(self): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Correctness: posterior-mean predicted probability must correlate with true + # P(Y=1|X). Residual corruption from reconstitute_from_forest (the GFR warm-start + # bug) produces near-random predictions that would fail this check. + p_true_test = prob[test_inds] + p_hat_mean = bart_model.predict( + X=X_test, type="mean", scale="probability", terms="y_hat" + ) + assert np.corrcoef(p_hat_mean, p_true_test)[0, 1] > 0.5 + def test_cloglog_ordinal_bart(self): # RNG random_seed = 101 @@ -1573,3 +1582,12 @@ def test_cloglog_ordinal_bart_with_gfr(self): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) assert bart_model.cloglog_cutpoint_samples.shape == (2, num_mcmc) + + # Correctness: predicted P(Y=1) must correlate with the true P(Y=1). + # Residual corruption from reconstitute_from_forest produces near-random + # predictions that would fail this check. 
+ true_probs_test = true_probs[test_inds, :] + preds_mean_prob = bart_model.predict( + X=X_test, type="mean", scale="probability", terms="y_hat" + ) + assert np.corrcoef(preds_mean_prob[:, 0], true_probs_test[:, 0])[0, 1] > 0.3 From 1ffe7c703a1c677e8f47d90cd168bf1cbb8e934d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Apr 2026 11:26:54 -0500 Subject: [PATCH 60/64] Updated cloglog benchmark scripts --- debug/benchmark_cpp_vs_py_sampler_cloglog.py | 37 ++++++++------- ...hmark_cpp_vs_py_sampler_ordinal_cloglog.py | 42 +++++++++-------- debug/benchmark_cpp_vs_r_sampler_cloglog.R | 37 ++++++--------- ...nchmark_cpp_vs_r_sampler_ordinal_cloglog.R | 45 +++++++------------ 4 files changed, 70 insertions(+), 91 deletions(-) diff --git a/debug/benchmark_cpp_vs_py_sampler_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_cloglog.py index 9d6eacb4..2286de19 100644 --- a/debug/benchmark_cpp_vs_py_sampler_cloglog.py +++ b/debug/benchmark_cpp_vs_py_sampler_cloglog.py @@ -4,8 +4,9 @@ run_cpp=True / False in BARTModel.sample(). DGP uses the cloglog link: P(Y=1|X) = 1 - exp(-exp(f(X))). -The step function for f(X) is kept in the range [-2, 1] so that the implied -probabilities span roughly 0.13 to 0.93 and are well-identified. +f(X) is a sum of smooth sinusoidal terms across two covariates, keeping +probabilities in the moderate range [0.25, 0.75] for stable mixing. +GFR is disabled (num_gfr=0) since it interacts poorly with cloglog binary. Usage: conda activate stochtree-book # or: source venv/bin/activate @@ -26,12 +27,12 @@ X = rng.uniform(size=(n, p)) # Latent mean on the cloglog (log-log) scale. -# P(Y=1|X) = 1 - exp(-exp(f_X)); values chosen so probabilities are moderate. +# f_X is centred near -0.5 so that P(Y=1|X) = 1 - exp(-exp(f_X)) stays +# in [~0.25, ~0.75], avoiding extreme probabilities that inflate GFR depth. 
f_X = ( - np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -2.0, 0) + - np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -0.5, 0) + - np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 0.5, 0) + - np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 1.0, 0) + 0.6 * np.sin(2 * np.pi * X[:, 0]) + + 0.4 * np.cos(2 * np.pi * X[:, 1]) + - 0.5 ) p_X = 1.0 - np.exp(-np.exp(f_X)) # true P(Y = 1 | X) y = rng.binomial(1, p_X).astype(float) # observed binary outcome @@ -49,20 +50,22 @@ # --------------------------------------------------------------------------- # Benchmark settings # --------------------------------------------------------------------------- -num_gfr = 10 -num_mcmc = 100 -num_trees = 200 -n_reps = 3 +num_gfr = 0 +num_burnin = 100 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 print( f"n_train={n_train} n_test={n_test} p={p} " - f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" + f"num_trees={num_trees} num_gfr={num_gfr} num_burnin={num_burnin} " + f"num_mcmc={num_mcmc} reps={n_reps}\n" ) # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + metrics # --------------------------------------------------------------------------- -def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: +def run_once(run_cpp: bool, seed: int) -> dict: m = BARTModel() t0 = time.perf_counter() m.sample( @@ -70,7 +73,7 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: y_train=y_train, X_test=X_test, num_gfr=num_gfr, - num_burnin=0, + num_burnin=num_burnin, num_mcmc=num_mcmc, mean_forest_params={"num_trees": num_trees, "sample_sigma2_leaf": False}, general_params={ @@ -94,7 +97,7 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: # --------------------------------------------------------------------------- # Run benchmarks # --------------------------------------------------------------------------- -seeds = [1000 + i 
for i in range(1, n_reps + 1)] +seeds = [100 + i for i in range(1, n_reps + 1)] results_cpp = [] results_py = [] @@ -102,12 +105,12 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: print("Running C++ sampler (run_cpp=True)...") for i, seed in enumerate(seeds, 1): print(f" rep {i}/{n_reps}") - results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + results_cpp.append(run_once(run_cpp=True, seed=seed)) print("\nRunning Python sampler (run_cpp=False)...") for i, seed in enumerate(seeds, 1): print(f" rep {i}/{n_reps}") - results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + results_py.append(run_once(run_cpp=False, seed=seed)) # --------------------------------------------------------------------------- # Summarise diff --git a/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py index e6c568ed..e163cc3e 100644 --- a/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py +++ b/debug/benchmark_cpp_vs_py_sampler_ordinal_cloglog.py @@ -4,9 +4,9 @@ probabilities) across run_cpp=True / False in BARTModel.sample(). DGP uses 4 ordinal categories with a cloglog link. -The latent step function f(X) is on the log-log scale, and each category -boundary (gamma_k) is fixed at log(k) for k = 1, 2, 3 so the four -cumulative probabilities are P(Y <= k | X) = 1 - exp(-exp(f(X) - gamma_k)). +f(X) is a smooth sinusoidal function of two covariates so that all four +categories are well-populated. Cutpoints are spaced to yield roughly equal +marginal class frequencies. GFR is disabled (num_gfr=0). 
Usage: conda activate stochtree-book # or: source venv/bin/activate @@ -26,18 +26,14 @@ p = 10 X = rng.uniform(size=(n, p)) -# Latent step function on the cloglog scale -f_X = ( - np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -2.0, 0.0) + - np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -0.5, 0.0) + - np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 0.5, 0.0) + - np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 1.0, 0.0) -) +# Latent mean on the cloglog scale (smooth, two covariates) +f_X = 0.6 * np.sin(2 * np.pi * X[:, 0]) + 0.4 * np.cos(2 * np.pi * X[:, 1]) -# Fixed log-scale cutpoints (gamma_k); K = 4 categories => K-1 = 3 cutpoints -# gamma_0 is fixed at 0 for identifiability; gamma_1 = log(2), gamma_2 = log(3) +# Fixed log-scale cutpoints spaced to give roughly equal marginal class freqs. +# With f_X in roughly [-1, 1], gamma = [0, log(2), log(4)] puts the four +# cumulative boundaries at moderate probability levels. K = 4 -gamma_true = np.array([0.0, np.log(2), np.log(3)]) +gamma_true = np.array([0.0, np.log(2), np.log(4)]) # True cumulative probabilities: P(Y <= k | X) = 1 - exp(-exp(f_X - gamma_k)) # Shape: (n, K-1) @@ -72,20 +68,22 @@ # --------------------------------------------------------------------------- # Benchmark settings # --------------------------------------------------------------------------- -num_gfr = 10 -num_mcmc = 100 -num_trees = 200 -n_reps = 3 +num_gfr = 0 +num_burnin = 100 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 print( f"K={K} n_train={n_train} n_test={n_test} p={p} " - f"num_trees={num_trees} num_gfr={num_gfr} num_mcmc={num_mcmc} reps={n_reps}\n" + f"num_trees={num_trees} num_gfr={num_gfr} num_burnin={num_burnin} " + f"num_mcmc={num_mcmc} reps={n_reps}\n" ) # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + metrics # --------------------------------------------------------------------------- -def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, 
seed: int) -> dict: +def run_once(run_cpp: bool, seed: int) -> dict: m = BARTModel() t0 = time.perf_counter() m.sample( @@ -93,7 +91,7 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: y_train=y_train, X_test=X_test, num_gfr=num_gfr, - num_burnin=0, + num_burnin=num_burnin, num_mcmc=num_mcmc, mean_forest_params={"num_trees": num_trees, "sample_sigma2_leaf": False}, general_params={ @@ -128,12 +126,12 @@ def run_once(run_cpp: bool, num_gfr: int, num_mcmc: int, seed: int) -> dict: print("Running C++ sampler (run_cpp=True)...") for i, seed in enumerate(seeds, 1): print(f" rep {i}/{n_reps}") - results_cpp.append(run_once(run_cpp=True, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + results_cpp.append(run_once(run_cpp=True, seed=seed)) print("\nRunning Python sampler (run_cpp=False)...") for i, seed in enumerate(seeds, 1): print(f" rep {i}/{n_reps}") - results_py.append(run_once(run_cpp=False, num_gfr=num_gfr, num_mcmc=num_mcmc, seed=seed)) + results_py.append(run_once(run_cpp=False, seed=seed)) # --------------------------------------------------------------------------- # Summarise diff --git a/debug/benchmark_cpp_vs_r_sampler_cloglog.R b/debug/benchmark_cpp_vs_r_sampler_cloglog.R index 9da63934..16a2537a 100644 --- a/debug/benchmark_cpp_vs_r_sampler_cloglog.R +++ b/debug/benchmark_cpp_vs_r_sampler_cloglog.R @@ -3,8 +3,8 @@ ## run_cpp = TRUE / FALSE in bart(). ## ## DGP uses the cloglog link: P(Y=1|X) = 1 - exp(-exp(f(X))). -## The step function for f(X) is kept in the range [-2, 1] so that the implied -## probabilities span roughly 0.13 to 0.93 and are well-identified. +## f(X) is a smooth sinusoidal function of two covariates, keeping probabilities +## in [~0.25, ~0.75] for stable mixing. GFR is disabled (num_gfr = 0). 
## ## Usage: Rscript debug/benchmark_cpp_vs_r_sampler_cloglog.R ## or source() from an interactive session after devtools::load_all('.') @@ -20,12 +20,9 @@ p <- 10 X <- matrix(runif(n * p), ncol = p) # Latent mean on the cloglog (log-log) scale. -# P(Y=1|X) = 1 - exp(-exp(f_X)); values chosen so probabilities are moderate. -f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * - (-2.0) + - ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-0.5) + - ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (0.5) + - ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (1.0)) +# f_X is centred near -0.5 so that P(Y=1|X) = 1 - exp(-exp(f_X)) stays +# in [~0.25, ~0.75], avoiding extreme probabilities that inflate tree depth. +f_X <- 0.6 * sin(2 * pi * X[, 1]) + 0.4 * cos(2 * pi * X[, 2]) - 0.5 p_X <- 1 - exp(-exp(f_X)) # true P(Y = 1 | X) y <- rbinom(n, 1L, p_X) # observed binary outcome @@ -44,18 +41,20 @@ p_test <- p_X[test_inds] # --------------------------------------------------------------------------- # Benchmark settings # --------------------------------------------------------------------------- -num_gfr <- 10 +num_gfr <- 0 +num_burnin <- 100 num_mcmc <- 100 num_trees <- 200 n_reps <- 3 cat(sprintf( - "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_burnin=%d num_mcmc=%d reps=%d\n\n", n_train, n_test, p, num_trees, num_gfr, + num_burnin, num_mcmc, n_reps )) @@ -63,14 +62,14 @@ cat(sprintf( # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + metrics # --------------------------------------------------------------------------- -run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { +run_once <- function(run_cpp, seed = -1) { t0 <- proc.time() m <- bart( X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = num_gfr, - num_burnin = 0, + num_burnin = num_burnin, num_mcmc = num_mcmc, mean_forest_params = list( num_trees = num_trees, @@ 
-115,23 +114,13 @@ results_r <- vector("list", n_reps) cat("Running C++ sampler (run_cpp = TRUE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, n_reps)) - results_cpp[[i]] <- run_once( - run_cpp = TRUE, - num_gfr = num_gfr, - num_mcmc = num_mcmc, - seed = seeds[i] - ) + results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) } cat("\nRunning R sampler (run_cpp = FALSE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, n_reps)) - results_r[[i]] <- run_once( - run_cpp = FALSE, - num_gfr = num_gfr, - num_mcmc = num_mcmc, - seed = seeds[i] - ) + results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) } # --------------------------------------------------------------------------- diff --git a/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R index ad8236be..3088737b 100644 --- a/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R +++ b/debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R @@ -3,9 +3,9 @@ ## probabilities) across run_cpp = TRUE / FALSE in bart(). ## ## DGP uses 4 ordinal categories with a cloglog link. -## The latent step function f(X) is on the log-log scale, and each category -## boundary (gamma_k) is fixed at log(k) for k = 1, 2, 3 so the four -## cumulative probabilities are P(Y <= k | X) = 1 - exp(-exp(f(X) - gamma_k)). +## f(X) is a smooth sinusoidal function of two covariates so that all four +## categories are well-populated. Cutpoints are spaced to yield roughly equal +## marginal class frequencies. GFR is disabled (num_gfr = 0). 
## ## Usage: Rscript debug/benchmark_cpp_vs_r_sampler_ordinal_cloglog.R ## or source() from an interactive session after devtools::load_all('.') @@ -20,17 +20,14 @@ n <- 2000 p <- 10 X <- matrix(runif(n * p), ncol = p) -# Latent step function on the cloglog scale -f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * - (-2.0) + - ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-0.5) + - ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (0.5) + - ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (1.0)) +# Latent mean on the cloglog scale (smooth, two covariates) +f_X <- 0.6 * sin(2 * pi * X[, 1]) + 0.4 * cos(2 * pi * X[, 2]) -# Fixed log-scale cutpoints (gamma_k); K = 4 categories => K-1 = 3 cutpoints -# gamma_0 is fixed at 0 for identifiability; gamma_1 = log(2), gamma_2 = log(3) +# Fixed log-scale cutpoints spaced to give roughly equal marginal class freqs. +# With f_X in roughly [-1, 1], setting gamma = c(0, log(2), log(4)) puts the +# four cumulative boundaries at moderate probability levels. K <- 4 -gamma_true <- c(0, log(2), log(3)) +gamma_true <- c(0, log(2), log(4)) # True cumulative probabilities: P(Y <= k | X) = 1 - exp(-exp(f_X - gamma_k)) # True class probabilities: P(Y = k | X) = P(Y <= k) - P(Y <= k-1) @@ -64,19 +61,21 @@ p_test <- p_X[test_inds, ] # n_test x K matrix of true class probabilities # --------------------------------------------------------------------------- # Benchmark settings # --------------------------------------------------------------------------- -num_gfr <- 10 +num_gfr <- 0 +num_burnin <- 100 num_mcmc <- 100 num_trees <- 200 n_reps <- 3 cat(sprintf( - "K=%d n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_mcmc=%d reps=%d\n\n", + "K=%d n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_burnin=%d num_mcmc=%d reps=%d\n\n", K, n_train, n_test, p, num_trees, num_gfr, + num_burnin, num_mcmc, n_reps )) @@ -84,14 +83,14 @@ cat(sprintf( # --------------------------------------------------------------------------- # Helper: run one configuration and return timing + 
metrics # --------------------------------------------------------------------------- -run_once <- function(run_cpp, num_gfr, num_mcmc, seed = -1) { +run_once <- function(run_cpp, seed = -1) { t0 <- proc.time() m <- bart( X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = num_gfr, - num_burnin = 0, + num_burnin = num_burnin, num_mcmc = num_mcmc, mean_forest_params = list( num_trees = num_trees, @@ -137,23 +136,13 @@ results_r <- vector("list", n_reps) cat("Running C++ sampler (run_cpp = TRUE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, n_reps)) - results_cpp[[i]] <- run_once( - run_cpp = TRUE, - num_gfr = num_gfr, - num_mcmc = num_mcmc, - seed = seeds[i] - ) + results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) } cat("\nRunning R sampler (run_cpp = FALSE)...\n") for (i in seq_len(n_reps)) { cat(sprintf(" rep %d/%d\n", i, n_reps)) - results_r[[i]] <- run_once( - run_cpp = FALSE, - num_gfr = num_gfr, - num_mcmc = num_mcmc, - seed = seeds[i] - ) + results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) } # --------------------------------------------------------------------------- From 107d8870ba263774ac437f0294e5c6950e55fdf6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Apr 2026 14:15:32 -0500 Subject: [PATCH 61/64] Support heteroskedasticity with no mean model and add errors when case weights and variance forest are both included --- R/bart.R | 5 +- R/bcf.R | 5 +- ...hmark_cpp_vs_py_sampler_heteroskedastic.py | 161 ++++++++++++++++++ src/bart_sampler.cpp | 57 +++++-- stochtree/bart.py | 28 +-- stochtree/bcf.py | 5 +- 6 files changed, 226 insertions(+), 35 deletions(-) create mode 100644 debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py diff --git a/R/bart.R b/R/bart.R index 11bdfb01..b1e2af8d 100644 --- a/R/bart.R +++ b/R/bart.R @@ -511,8 +511,9 @@ bart <- function( ) } if (include_variance_forest) { - warning( - "Results may be unreliable when observation_weights are deployed alongside a variance 
forest model." + stop( + "observation_weights are not compatible with a variance forest model. ", + "Use either observation_weights or a variance forest, not both." ) } } diff --git a/R/bcf.R b/R/bcf.R index 0ab76098..d3a5b166 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -653,8 +653,9 @@ bcf <- function( ) } if (include_variance_forest) { - warning( - "Results may be unreliable when observation_weights are deployed alongside a variance forest model." + stop( + "observation_weights are not compatible with a variance forest model. ", + "Use either observation_weights or a variance forest, not both." ) } } diff --git a/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py b/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py new file mode 100644 index 00000000..ac9cc9f9 --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py @@ -0,0 +1,161 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop – heteroskedastic BART. + +Compares runtime, mean-forest RMSE (vs. true f(X)) and RMSE of the estimated +conditional standard deviation (vs. the true s(X)) across run_cpp=True / +False in BARTModel.sample() with both a mean forest and a variance forest +(num_trees_variance > 0). + +DGP: f(X) is a step function of X[:,0]; s(X) varies by quadrant of X[:,0] +and linearly with X[:,2], matching the heteroskedastic_bart.R debug script. + +Note: A variance-only model (num_trees_mean=0, num_trees_variance>0) is now +supported in the C++ path. The mean-forest RMSE is reported as NaN in that +case since there is no mean forest to evaluate. 
+ +Usage: + source venv/bin/activate + python debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py +""" + +import time +import numpy as np +from stochtree import BARTModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 2000 +p = 10 +X = rng.uniform(size=(n, p)) + +# True conditional mean and conditional std dev +f_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -3.0, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -1.0, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 1.0, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 3.0, 0) +) +s_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), 0.5 * X[:, 2], 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), 1.0 * X[:, 2], 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 2.0 * X[:, 2], 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 3.0 * X[:, 2], 0) +) +y = f_X + rng.standard_normal(n) * s_X + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +f_test = f_X[test_inds] +s_test = s_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +num_trees_mean = 0 +num_trees_variance = 50 +n_reps = 3 + +print( + f"n_train={n_train} n_test={n_test} p={p} " + f"num_trees_mean={num_trees_mean} num_trees_variance={num_trees_variance} " + f"num_gfr={num_gfr} num_burnin={num_burnin} num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# --------------------------------------------------------------------------- +# Helper: run one 
configuration and return timing + metrics +# --------------------------------------------------------------------------- +def run_once(run_cpp: bool, seed: int) -> dict: + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={ + "random_seed": seed, + "sample_sigma2_global": False, + }, + mean_forest_params={"num_trees": num_trees_mean}, + variance_forest_params={"num_trees": num_trees_variance}, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + # mean-forest RMSE vs. true f(X) – only defined when a mean forest was fitted + if num_trees_mean > 0: + f_hat = m.y_hat_test.mean(axis=1) + rmse_f = float(np.sqrt(np.mean((f_hat - f_test) ** 2))) + else: + rmse_f = float("nan") + # sigma2_x_test has shape (n_test, num_mcmc); take posterior mean of cond. std dev + s_hat = np.sqrt(m.sigma2_x_test).mean(axis=1) + rmse_s = float(np.sqrt(np.mean((s_hat - s_test) ** 2))) + return {"elapsed": elapsed, "rmse_f": rmse_f, "rmse_s": rmse_s} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results: list) -> dict: + elapsed = [r["elapsed"] for r in results] + rmse_f = [r["rmse_f"] for r in 
results] + rmse_s = [r["rmse_s"] for r in results] + return { + "elapsed_mean": float(np.mean(elapsed)), + "elapsed_sd": float(np.std(elapsed, ddof=1)), + "rmse_f_mean": float(np.nanmean(rmse_f)), + "rmse_s_mean": float(np.mean(rmse_s)), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print( + f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} {'RMSE f(X)':>12} {'RMSE s(X)':>12}" +) +print("-" * 74) +for label, s in rows: + print( + f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['rmse_f_mean']:>12.4f} {s['rmse_s_mean']:>12.4f}" + ) + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"RMSE f(X) delta (cpp - py): {s_cpp['rmse_f_mean'] - s_py['rmse_f_mean']:.4f}\n" + f"RMSE s(X) delta (cpp - py): {s_cpp['rmse_s_mean'] - s_py['rmse_s_mean']:.4f}" +) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index bf00045a..442a5d6d 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -69,22 +69,8 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } double y_var = M2 / data_.n_train; - // Standardization and calibration for mean forests + // Standardization, calibration, and initialization for mean forests if (config_.num_trees_mean > 0) { - // Initialize leaf model - if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianConstant) { - mean_leaf_model_ = GaussianConstantLeafModel(config_.sigma2_mean_init); - } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianUnivariateRegression) { - mean_leaf_model_ = GaussianUnivariateRegressionLeafModel(config_.sigma2_mean_init); - } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianMultivariateRegression) { - // TODO - // mean_leaf_model_ = GaussianMultivariateRegressionLeafModel(...); - } else if (config_.mean_leaf_model_type == MeanLeafModelType::CloglogOrdinal) 
{ - mean_leaf_model_ = CloglogOrdinalLeafModel(config_.cloglog_leaf_prior_shape, config_.cloglog_leaf_prior_scale); - } else { - Log::Fatal("Unsupported leaf model type for mean forest"); - } - if (config_.link_function == LinkFunction::Probit) { // Initialize forests to 0, no scaling, but offset by the probit transform of the mean outcome to improve mixing samples.y_std = 1.0; @@ -116,6 +102,8 @@ void BARTSampler::InitializeState(BARTSamples& samples) { // TODO ... } } + + // Calibrate leaf scale and variance model priors if (config_.sigma2_mean_init < 0.0) { if (config_.link_function == LinkFunction::Probit) { config_.sigma2_mean_init = 1.0 / config_.num_trees_mean; @@ -132,6 +120,30 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } } } + + // Initialize leaf model + if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianConstant) { + mean_leaf_model_ = GaussianConstantLeafModel(config_.sigma2_mean_init); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianUnivariateRegression) { + mean_leaf_model_ = GaussianUnivariateRegressionLeafModel(config_.sigma2_mean_init); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianMultivariateRegression) { + // TODO + // mean_leaf_model_ = GaussianMultivariateRegressionLeafModel(...); + } else if (config_.mean_leaf_model_type == MeanLeafModelType::CloglogOrdinal) { + mean_leaf_model_ = CloglogOrdinalLeafModel(config_.cloglog_leaf_prior_shape, config_.cloglog_leaf_prior_scale); + } else { + Log::Fatal("Unsupported leaf model type for mean forest"); + } + } else { + // Variance-only model (num_trees_mean == 0): no mean forest, but y_bar/y_std must + // still be valid so the residual initialisation below doesn't divide by zero. 
+ if (config_.standardize_outcome) { + samples.y_bar = y_mean; + samples.y_std = std::sqrt(y_var); + } else { + samples.y_bar = 0.0; + samples.y_std = 1.0; + } } // Calibration for variance forests @@ -174,8 +186,21 @@ void BARTSampler::InitializeState(BARTSamples& samples) { samples.variance_forests = std::make_unique(config_.num_trees_variance, config_.leaf_dim_variance, config_.leaf_constant_variance, config_.exponentiated_leaf_variance); variance_forest_tracker_ = std::make_unique(forest_dataset_->GetCovariates(), config_.feature_types, config_.num_trees_variance, data_.n_train); tree_prior_variance_ = std::make_unique(config_.alpha_variance, config_.beta_variance, config_.min_samples_leaf_variance, config_.max_depth_variance); - variance_forest_->SetLeafValue(init_val_variance_ / config_.num_trees_variance); + // Leaf values for the log-linear variance model are on the log scale; the ensemble sums + // log(sigma^2_i) contributions, so each tree starts at log(init_val) / num_trees. + variance_forest_->SetLeafValue(std::log(init_val_variance_) / config_.num_trees_variance); variance_forest_tracker_->UpdatePredictions(variance_forest_.get(), *forest_dataset_.get()); + // UpdateVarModelTree (called inside GFRSampleOneIter / MCMCSampleOneIter) unconditionally + // reads and writes the dataset variance weight slot via VarWeightValue / SetVarWeightValue. + // This slot tracks the cumulative per-observation variance prediction + // (sigma^2_i = exp(sum of tree leaf values)) and is incompatible with case weights, which + // would need to be reapplied after every per-tree update. The R/Python APIs enforce this + // as a hard error; guard here for callers that use BARTSampler directly. 
+ if (forest_dataset_->HasVarWeights()) { + Log::Fatal("observation_weights and a variance forest cannot be used together."); + } + std::vector initial_variance_preds(data_.n_train, init_val_variance_); + forest_dataset_->AddVarianceWeights(initial_variance_preds.data(), data_.n_train); has_variance_forest_ = true; } diff --git a/stochtree/bart.py b/stochtree/bart.py index 3690505a..bdcf4242 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -451,8 +451,9 @@ def sample( "observation_weights are not compatible with cloglog link functions." ) if self.include_variance_forest: - warnings.warn( - "Results may be unreliable when observation_weights are deployed alongside a variance forest model." + raise ValueError( + "observation_weights are not compatible with a variance forest model." + "Use either observation_weights or a variance forest, not both." ) # Check data inputs @@ -1237,17 +1238,18 @@ def sample( self.scale_variance_forest = bart_results["scale_variance_forest"] if self.include_variance_forest else None # Unpack mean forest results - self.forest_container_mean = ( - ForestContainer(num_trees_mean, 1, True, False) - if not self.has_basis - else ForestContainer(num_trees_mean, self.num_basis, False, False) - ) - self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] - mean_forest_preds_train = bart_results["mean_forest_predictions_train"].reshape(self.n_train, bart_results["num_samples"], order="F") - self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar - if self.has_test: - mean_forest_preds_test = bart_results["mean_forest_predictions_test"].reshape(self.n_test, bart_results["num_samples"], order="F") - self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar + if self.include_mean_forest: + self.forest_container_mean = ( + ForestContainer(num_trees_mean, 1, True, False) + if not self.has_basis + else ForestContainer(num_trees_mean, self.num_basis, False, False) + ) + 
self.forest_container_mean.forest_container_cpp = bart_results["forest_container_mean"] + mean_forest_preds_train = bart_results["mean_forest_predictions_train"].reshape(self.n_train, bart_results["num_samples"], order="F") + self.y_hat_train = mean_forest_preds_train * self.y_std + self.y_bar + if self.has_test: + mean_forest_preds_test = bart_results["mean_forest_predictions_test"].reshape(self.n_test, bart_results["num_samples"], order="F") + self.y_hat_test = mean_forest_preds_test * self.y_std + self.y_bar # Unpack variance forest results if self.include_variance_forest: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ac04bda9..f6d124ff 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -647,8 +647,9 @@ def sample( "observation_weights are not compatible with cloglog link functions." ) if self.include_variance_forest: - warnings.warn( - "Results may be unreliable when observation_weights are deployed alongside a variance forest model." + raise ValueError( + "observation_weights are not compatible with a variance forest model. " + "Use either observation_weights or a variance forest, not both." 
) # Check data inputs From 17379b319aa6f19d05ef7f160bbd55bf7a789bbe Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Apr 2026 14:55:42 -0500 Subject: [PATCH 62/64] Fix heteroskedasticity-only code path and add heteroskedasticity benchmarking R script --- R/bart.R | 10 +- ...hmark_cpp_vs_py_sampler_heteroskedastic.py | 2 +- ...nchmark_cpp_vs_r_sampler_heteroskedastic.R | 189 ++++++++++++++++++ src/bart_sampler.cpp | 2 +- stochtree/bart.py | 1 + 5 files changed, 200 insertions(+), 4 deletions(-) create mode 100644 debug/benchmark_cpp_vs_r_sampler_heteroskedastic.R diff --git a/R/bart.R b/R/bart.R index b1e2af8d..c2f5e64a 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1321,18 +1321,24 @@ bart <- function( bart_results[["num_train"]], bart_results[["num_samples"]] ) + y_std_cpp <- bart_results[["y_std"]] result[["sigma2_x_hat_train"]] <- bart_results[[ "variance_forest_predictions_train" - ]] + ]] * + y_std_cpp * + y_std_cpp } if (has_variance_forest_predictions_test) { dim(bart_results[['variance_forest_predictions_test']]) <- c( bart_results[["num_test"]], bart_results[["num_samples"]] ) + y_std_cpp <- bart_results[["y_std"]] result[["sigma2_x_hat_test"]] <- bart_results[[ "variance_forest_predictions_test" - ]] + ]] * + y_std_cpp * + y_std_cpp } if ( has_variance_forest_predictions_train || diff --git a/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py b/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py index ac9cc9f9..70a50342 100644 --- a/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py +++ b/debug/benchmark_cpp_vs_py_sampler_heteroskedastic.py @@ -62,7 +62,7 @@ num_gfr = 10 num_burnin = 0 num_mcmc = 100 -num_trees_mean = 0 +num_trees_mean = 200 num_trees_variance = 50 n_reps = 3 diff --git a/debug/benchmark_cpp_vs_r_sampler_heteroskedastic.R b/debug/benchmark_cpp_vs_r_sampler_heteroskedastic.R new file mode 100644 index 00000000..0600a418 --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler_heteroskedastic.R @@ -0,0 +1,189 @@ +## Benchmark: C++ 
sampler loop vs. R sampler loop -- heteroskedastic BART. +## +## Compares runtime, mean-forest RMSE (vs. true f(X)) and RMSE of the estimated +## conditional standard deviation (vs. the true s(X)) across run_cpp = TRUE / +## FALSE in bart() with both a mean forest and a variance forest +## (num_trees_variance > 0). +## +## DGP: f(X) is a step function of X[,1]; s(X) varies by quadrant of X[,1] +## and linearly with X[,3], matching the heteroskedastic Python benchmark. +## +## A variance-only model (num_trees_mean = 0, num_trees_variance > 0) is +## supported in the C++ path. The mean-forest RMSE is reported as NA in +## that case since there is no mean forest to evaluate. +## +## Usage: +## Rscript debug/benchmark_cpp_vs_r_sampler_heteroskedastic.R +## or source() from an interactive session after devtools::load_all('.') + +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) + +# True conditional mean +f_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-3.0) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-1.0) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (1.0) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (3.0)) + +# True conditional standard deviation +s_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (0.5 * X[, 3]) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (1.0 * X[, 3]) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (2.0 * X[, 3]) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (3.0 * X[, 3])) + +y <- f_X + rnorm(n, 0, 1) * s_X + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +f_test <- f_X[test_inds] +s_test <- 
s_X[test_inds] + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_trees_mean <- 200 +num_trees_variance <- 50 +n_reps <- 3 + +cat(sprintf( + "n_train=%d n_test=%d p=%d num_trees_mean=%d num_trees_variance=%d num_gfr=%d num_burnin=%d num_mcmc=%d reps=%d\n\n", + n_train, + n_test, + p, + num_trees_mean, + num_trees_variance, + num_gfr, + num_burnin, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + metrics +# --------------------------------------------------------------------------- +run_once <- function(run_cpp, seed) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list(random_seed = seed, sample_sigma2_global = FALSE), + mean_forest_params = list(num_trees = num_trees_mean), + variance_forest_params = list(num_trees = num_trees_variance), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - t0)[["elapsed"]] + + # Mean-forest RMSE -- only defined when a mean forest was fitted + if (num_trees_mean > 0) { + f_hat <- rowMeans(m$y_hat_test) + rmse_f <- sqrt(mean((f_hat - f_test)^2)) + } else { + rmse_f <- NA_real_ + } + + # Variance-forest RMSE of estimated conditional std dev vs. 
true s(X) + sigma2_x_hat_test <- extractParameter(m, "sigma2_x_test") + s_hat <- rowMeans(sqrt(sigma2_x_hat_test)) + rmse_s <- sqrt(mean((s_hat - s_test)^2)) + + list(elapsed = elapsed, rmse_f = rmse_f, rmse_s = rmse_s) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + rmse_f <- sapply(results, `[[`, "rmse_f") + rmse_s <- sapply(results, `[[`, "rmse_s") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = sd(elapsed), + rmse_f_mean = mean(rmse_f, na.rm = TRUE), + rmse_s_mean = mean(rmse_s), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %12s\n", + "Sampler", + "Time (s)", + "SD", + "RMSE f(X)", + "RMSE s(X)" +)) +cat(strrep("-", 74), "\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12s %12.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + if (is.nan(res$rmse_f_mean[i])) { + "nan" + } else { + sprintf("%.4f", res$rmse_f_mean[i]) + }, + res$rmse_s_mean[i] + )) +} + 
+speedup <- res$elapsed_mean[2] / res$elapsed_mean[1] +cat(sprintf("\nSpeedup (R / C++): %.2fx\n", speedup)) +cat(sprintf( + "RMSE s(X) delta (cpp - R): %.4f\n", + res$rmse_s_mean[1] - res$rmse_s_mean[2] +)) diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 442a5d6d..0b014978 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -211,7 +211,7 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } // Leaf scale model - if (config_.sample_sigma2_leaf_mean) { + if (config_.sample_sigma2_leaf_mean && config_.num_trees_mean > 0) { leaf_scale_model_ = std::make_unique(); sample_sigma2_leaf_ = true; } diff --git a/stochtree/bart.py b/stochtree/bart.py index bdcf4242..75eef445 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1000,6 +1000,7 @@ def sample( # Preliminary runtime checks for probit link if not self.include_mean_forest: link_is_probit = False + sample_sigma2_leaf = False if link_is_probit: if np.unique(y_train).size != 2: raise ValueError( From 182c0a5cda0218b4a2419f0b524a3e38b60a79ff Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Apr 2026 16:02:37 -0500 Subject: [PATCH 63/64] Adding initialization routines for univariate leaf regression --- include/stochtree/bart_sampler.h | 9 ++++++++- src/bart_sampler.cpp | 12 +++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/include/stochtree/bart_sampler.h b/include/stochtree/bart_sampler.h index 771e81c2..170ddec6 100644 --- a/include/stochtree/bart_sampler.h +++ b/include/stochtree/bart_sampler.h @@ -59,7 +59,14 @@ class BARTSampler { sampler.has_mean_forest_ = true; } void operator()(GaussianUnivariateRegressionLeafModel& model) { - // TODO ... 
+ sampler.mean_forest_ = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + samples.mean_forests = std::make_unique(sampler.config_.num_trees_mean, sampler.config_.leaf_dim_mean, sampler.config_.leaf_constant_mean, sampler.config_.exponentiated_leaf_mean); + sampler.mean_forest_tracker_ = std::make_unique(sampler.forest_dataset_->GetCovariates(), sampler.config_.feature_types, sampler.config_.num_trees_mean, sampler.data_.n_train); + sampler.tree_prior_mean_ = std::make_unique(sampler.config_.alpha_mean, sampler.config_.beta_mean, sampler.config_.min_samples_leaf_mean, sampler.config_.max_depth_mean); + sampler.mean_forest_->SetLeafValue(sampler.init_val_mean_ / sampler.config_.num_trees_mean); + UpdateResidualEntireForest(*sampler.mean_forest_tracker_, *sampler.forest_dataset_, *sampler.residual_, sampler.mean_forest_.get(), !sampler.config_.leaf_constant_mean, std::minus()); + sampler.mean_forest_tracker_->UpdatePredictions(sampler.mean_forest_.get(), *sampler.forest_dataset_.get()); + sampler.has_mean_forest_ = true; } void operator()(GaussianMultivariateRegressionLeafModel& model) { // TODO ... diff --git a/src/bart_sampler.cpp b/src/bart_sampler.cpp index 0b014978..323ec8d9 100644 --- a/src/bart_sampler.cpp +++ b/src/bart_sampler.cpp @@ -96,7 +96,17 @@ void BARTSampler::InitializeState(BARTSamples& samples) { } } else if (config_.mean_leaf_model_type == MeanLeafModelType::GaussianUnivariateRegression) { // Case 2: Univariate leaf regression - // TODO ... 
+ if (config_.standardize_outcome) { + samples.y_bar = y_mean; + samples.y_std = std::sqrt(y_var); + } else { + samples.y_bar = 0.0; + samples.y_std = 1.0; + } + // Always map initial leaf value to zero + // Users fitting a univariate leaf regression (with a non-centered basis) should standardize their outcomes + // TODO: consider adding warning in R / Python if univariate regression leaf model is specified without standardization + init_val_mean_ = 0.0; } else { // Case 3: Multivariate leaf regression // TODO ... From db1c823a3dac75c8771d6f8036b19d9e5fefe9ce Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Apr 2026 16:13:08 -0500 Subject: [PATCH 64/64] Added benchmarking scripts for univariate leaf regression --- ...hmark_cpp_vs_py_sampler_leaf_regression.py | 149 +++++++++++++++ ...nchmark_cpp_vs_r_sampler_leaf_regression.R | 171 ++++++++++++++++++ 2 files changed, 320 insertions(+) create mode 100644 debug/benchmark_cpp_vs_py_sampler_leaf_regression.py create mode 100644 debug/benchmark_cpp_vs_r_sampler_leaf_regression.R diff --git a/debug/benchmark_cpp_vs_py_sampler_leaf_regression.py b/debug/benchmark_cpp_vs_py_sampler_leaf_regression.py new file mode 100644 index 00000000..943d9b3f --- /dev/null +++ b/debug/benchmark_cpp_vs_py_sampler_leaf_regression.py @@ -0,0 +1,149 @@ +"""Benchmark: C++ sampler loop vs. Python sampler loop -- univariate leaf regression. + +Compares runtime and test-set RMSE across run_cpp=True / False in BARTModel.sample() +with a univariate leaf regression basis (leaf_basis_train with one column). + +DGP: f(X, Z) = tau(X) * Z, where tau(X) is a step function of X[:,0] and +Z is drawn uniform [0, 1]. A constant noise term is added. The leaf basis +passed to the sampler is just Z (shape n x 1). 
+ +Usage: + source venv/bin/activate + python debug/benchmark_cpp_vs_py_sampler_leaf_regression.py +""" + +import time +import numpy as np +from stochtree import BARTModel + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +rng = np.random.default_rng(1234) + +n = 2000 +p = 10 +X = rng.uniform(size=(n, p)) +Z = rng.uniform(size=n) # scalar moderating variable / leaf basis + +# Heterogeneous slope on Z, partitioned by X[:,0] +tau_X = ( + np.where((X[:, 0] >= 0.00) & (X[:, 0] < 0.25), -2.0, 0) + + np.where((X[:, 0] >= 0.25) & (X[:, 0] < 0.50), -1.0, 0) + + np.where((X[:, 0] >= 0.50) & (X[:, 0] < 0.75), 1.0, 0) + + np.where((X[:, 0] >= 0.75) & (X[:, 0] < 1.00), 2.0, 0) +) +f_XZ = tau_X * Z +noise_sd = 1.0 +y = f_XZ + rng.normal(scale=noise_sd, size=n) + +test_frac = 0.2 +n_test = round(test_frac * n) +n_train = n - n_test +test_inds = rng.choice(n, size=n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) + +X_train, X_test = X[train_inds], X[test_inds] +Z_train, Z_test = Z[train_inds], Z[test_inds] +y_train, y_test = y[train_inds], y[test_inds] +f_test = f_XZ[test_inds] + +# Leaf basis matrices (n x 1) +basis_train = Z_train.reshape(-1, 1) +basis_test = Z_test.reshape(-1, 1) + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +num_trees = 200 +n_reps = 3 + +print( + f"n_train={n_train} n_test={n_test} p={p} " + f"num_trees={num_trees} num_gfr={num_gfr} num_burnin={num_burnin} " + f"num_mcmc={num_mcmc} reps={n_reps}\n" +) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + RMSE +# --------------------------------------------------------------------------- +def 
run_once(run_cpp: bool, seed: int) -> dict: + m = BARTModel() + t0 = time.perf_counter() + m.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={"random_seed": seed}, + mean_forest_params={"num_trees": num_trees}, + run_cpp=run_cpp, + ) + elapsed = time.perf_counter() - t0 + + yhat = m.y_hat_test.mean(axis=1) + rmse = float(np.sqrt(np.mean((yhat - y_test) ** 2))) + rmse_f = float(np.sqrt(np.mean((yhat - f_test) ** 2))) + return {"elapsed": elapsed, "rmse": rmse, "rmse_f": rmse_f} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds = [1000 + i for i in range(1, n_reps + 1)] + +results_cpp = [] +results_py = [] + +print("Running C++ sampler (run_cpp=True)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_cpp.append(run_once(run_cpp=True, seed=seed)) + +print("\nRunning Python sampler (run_cpp=False)...") +for i, seed in enumerate(seeds, 1): + print(f" rep {i}/{n_reps}") + results_py.append(run_once(run_cpp=False, seed=seed)) + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +def summarise(results: list) -> dict: + elapsed = [r["elapsed"] for r in results] + rmse = [r["rmse"] for r in results] + rmse_f = [r["rmse_f"] for r in results] + return { + "elapsed_mean": float(np.mean(elapsed)), + "elapsed_sd": float(np.std(elapsed, ddof=1)), + "rmse_mean": float(np.mean(rmse)), + "rmse_f_mean": float(np.mean(rmse_f)), + } + +s_cpp = summarise(results_cpp) +s_py = summarise(results_py) +rows = [("cpp (run_cpp=True)", s_cpp), ("py (run_cpp=False)", s_py)] + +print("\n--- Results ---") +print( + f"{'Sampler':<22} {'Time (s)':>10} {'SD':>10} 
{'RMSE (obs)':>12} {'RMSE f(X,Z)':>13}" +) +print("-" * 74) +for label, s in rows: + print( + f"{label:<22} {s['elapsed_mean']:>10.3f} {s['elapsed_sd']:>10.3f}" + f" {s['rmse_mean']:>12.4f} {s['rmse_f_mean']:>13.4f}" + ) + +speedup = s_py["elapsed_mean"] / s_cpp["elapsed_mean"] +print(f"\nSpeedup (py / cpp): {speedup:.2f}x") +print( + f"RMSE delta (cpp - py): " + f"obs={s_cpp['rmse_mean'] - s_py['rmse_mean']:.4f} " + f"f={s_cpp['rmse_f_mean'] - s_py['rmse_f_mean']:.4f}" +) diff --git a/debug/benchmark_cpp_vs_r_sampler_leaf_regression.R b/debug/benchmark_cpp_vs_r_sampler_leaf_regression.R new file mode 100644 index 00000000..9cbd1923 --- /dev/null +++ b/debug/benchmark_cpp_vs_r_sampler_leaf_regression.R @@ -0,0 +1,171 @@ +## Benchmark: C++ sampler loop vs. R sampler loop -- univariate leaf regression. +## +## Compares runtime and test-set RMSE across run_cpp = TRUE / FALSE in bart() +## with a univariate leaf regression basis (leaf_basis_train with one column). +## +## DGP: f(X, Z) = tau(X) * Z, where tau(X) is a step function of X[,1] and +## Z is drawn uniform [0, 1]. A constant noise term is added. The leaf basis +## passed to the sampler is just Z (shape n x 1). 
+## +## Usage: +## Rscript debug/benchmark_cpp_vs_r_sampler_leaf_regression.R +## or source() from an interactive session after devtools::load_all('.') + +library(stochtree) + +# --------------------------------------------------------------------------- +# Data-generating process +# --------------------------------------------------------------------------- +set.seed(1234) + +n <- 2000 +p <- 10 +X <- matrix(runif(n * p), ncol = p) +Z <- runif(n) # scalar moderating variable / leaf basis + +# Heterogeneous slope on Z, partitioned by X[,1] +tau_X <- (((0.00 <= X[, 1]) & (X[, 1] < 0.25)) * + (-2.0) + + ((0.25 <= X[, 1]) & (X[, 1] < 0.50)) * (-1.0) + + ((0.50 <= X[, 1]) & (X[, 1] < 0.75)) * (1.0) + + ((0.75 <= X[, 1]) & (X[, 1] < 1.00)) * (2.0)) +f_XZ <- tau_X * Z +noise_sd <- 1.0 +y <- f_XZ + rnorm(n, 0, noise_sd) + +test_frac <- 0.2 +n_test <- round(test_frac * n) +n_train <- n - n_test +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) + +X_train <- X[train_inds, ] +X_test <- X[test_inds, ] +Z_train <- Z[train_inds] +Z_test <- Z[test_inds] +y_train <- y[train_inds] +y_test <- y[test_inds] +f_test <- f_XZ[test_inds] + +# Leaf basis matrices (n x 1) +basis_train <- matrix(Z_train, ncol = 1) +basis_test <- matrix(Z_test, ncol = 1) + +# --------------------------------------------------------------------------- +# Benchmark settings +# --------------------------------------------------------------------------- +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_trees <- 200 +n_reps <- 3 + +cat(sprintf( + "n_train=%d n_test=%d p=%d num_trees=%d num_gfr=%d num_burnin=%d num_mcmc=%d reps=%d\n\n", + n_train, + n_test, + p, + num_trees, + num_gfr, + num_burnin, + num_mcmc, + n_reps +)) + +# --------------------------------------------------------------------------- +# Helper: run one configuration and return timing + RMSE +# --------------------------------------------------------------------------- +run_once <- 
function(run_cpp, seed) { + t0 <- proc.time() + m <- bart( + X_train = X_train, + y_train = y_train, + leaf_basis_train = basis_train, + X_test = X_test, + leaf_basis_test = basis_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list(random_seed = seed), + mean_forest_params = list(num_trees = num_trees), + run_cpp = run_cpp + ) + elapsed <- (proc.time() - t0)[["elapsed"]] + + yhat <- rowMeans(m$y_hat_test) + rmse <- sqrt(mean((yhat - y_test)^2)) + rmse_f <- sqrt(mean((yhat - f_test)^2)) + + list(elapsed = elapsed, rmse = rmse, rmse_f = rmse_f) +} + +# --------------------------------------------------------------------------- +# Run benchmarks +# --------------------------------------------------------------------------- +seeds <- 1000 + seq_len(n_reps) + +results_cpp <- vector("list", n_reps) +results_r <- vector("list", n_reps) + +cat("Running C++ sampler (run_cpp = TRUE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_cpp[[i]] <- run_once(run_cpp = TRUE, seed = seeds[i]) +} + +cat("\nRunning R sampler (run_cpp = FALSE)...\n") +for (i in seq_len(n_reps)) { + cat(sprintf(" rep %d/%d\n", i, n_reps)) + results_r[[i]] <- run_once(run_cpp = FALSE, seed = seeds[i]) +} + +# --------------------------------------------------------------------------- +# Summarise +# --------------------------------------------------------------------------- +summarise <- function(results, label) { + elapsed <- sapply(results, `[[`, "elapsed") + rmse <- sapply(results, `[[`, "rmse") + rmse_f <- sapply(results, `[[`, "rmse_f") + data.frame( + sampler = label, + elapsed_mean = mean(elapsed), + elapsed_sd = sd(elapsed), + rmse_mean = mean(rmse), + rmse_f_mean = mean(rmse_f), + row.names = NULL + ) +} + +res <- rbind( + summarise(results_cpp, "cpp (run_cpp=TRUE)"), + summarise(results_r, "R (run_cpp=FALSE)") +) + +cat("\n--- Results ---\n") +cat(sprintf( + "%-22s %10s %10s %12s %13s\n", + "Sampler", + 
"Time (s)", + "SD", + "RMSE (obs)", + "RMSE f(X,Z)" +)) +cat(strrep("-", 74), "\n") +for (i in seq_len(nrow(res))) { + cat(sprintf( + "%-22s %10.3f %10.3f %12.4f %13.4f\n", + res$sampler[i], + res$elapsed_mean[i], + res$elapsed_sd[i], + res$rmse_mean[i], + res$rmse_f_mean[i] + )) +} + +speedup <- res$elapsed_mean[2] / res$elapsed_mean[1] +cat(sprintf("\nSpeedup (R / C++): %.2fx\n", speedup)) +cat(sprintf( + "RMSE delta (cpp - R): obs=%.4f f=%.4f\n", + res$rmse_mean[1] - res$rmse_mean[2], + res$rmse_f_mean[1] - res$rmse_f_mean[2] +))