Skip to content

Commit 7c75236

Browse files
authored
Merge pull request #411 from StochasticTree/fix-bart-categorical-mean-only-weights
Fix IndexError in mean-only Python BART with categorical covariates
2 parents 270afe8 + 13d475f commit 7c75236

3 files changed

Lines changed: 38 additions & 7 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Bug Fixes
44

5+
* Fix `IndexError` when sampling a mean-only Python BART model with categorical covariates; the excluded-variable weight zero-out is now guarded by the forest-inclusion flags, matching the R implementation [#411](https://github.com/StochasticTree/stochtree/pull/411)
56
* Fix ordinal class prediction bug in R BART [#399](https://github.com/StochasticTree/stochtree/issues/399)
67

78
# stochtree 0.4.4

stochtree/bart.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -990,13 +990,19 @@ def sample(
990990
* variable_weights_adj
991991
)
992992

993-
# Zero out weights for excluded variables
994-
variable_weights_mean[
995-
[variable_subset_mean.count(i) == 0 for i in original_var_indices]
996-
] = 0
997-
variable_weights_variance[
998-
[variable_subset_variance.count(i) == 0 for i in original_var_indices]
999-
] = 0
993+
# Zero out weights for excluded variables. The weight arrays are only
994+
# expanded to processed (post-preprocessing) length inside the
995+
# include_*_forest guards above, so the zero-out must be guarded the same
996+
# way -- otherwise a mean-only (or variance-only) model with categorical
997+
# covariates indexes an unexpanded array and raises. (Matches R's logic.)
998+
if self.include_mean_forest:
999+
variable_weights_mean[
1000+
[variable_subset_mean.count(i) == 0 for i in original_var_indices]
1001+
] = 0
1002+
if self.include_variance_forest:
1003+
variable_weights_variance[
1004+
[variable_subset_variance.count(i) == 0 for i in original_var_indices]
1005+
] = 0
10001006

10011007
# Set num_features_subsample to default, ncol(X_train), if not already set
10021008
if num_features_subsample_mean is None:

test/python/test_bart.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23
import pytest
34
from sklearn.model_selection import train_test_split
45

@@ -1578,6 +1579,29 @@ def test_cloglog_ordinal_bart_with_gfr(self):
15781579
assert bart_model.y_hat_test.shape == (n_test, num_mcmc)
15791580
assert bart_model.cloglog_cutpoint_samples.shape == (2, num_mcmc)
15801581

1582+
def test_categorical_covariates_mean_only(self):
1583+
"""A mean-only BART model with categorical (one-hot expanded) covariates
1584+
must sample and predict without error.
1585+
1586+
Regression test: the "zero out excluded variable weights" step ran
1587+
outside the include_*_forest guards, so variable_weights_variance was
1588+
never expanded to the processed (one-hot) length for a model without a
1589+
variance forest, and indexing it raised an IndexError.
1590+
"""
1591+
rng = np.random.default_rng(0)
1592+
n = 100
1593+
X_num = rng.uniform(0, 1, (n, 3))
1594+
X = pd.DataFrame(X_num, columns=["a", "b", "c"])
1595+
X["cat"] = pd.Categorical(rng.choice(["x", "y", "z"], size=n))
1596+
y = X_num[:, 0] + rng.normal(scale=0.5, size=n)
1597+
1598+
model = BARTModel()
1599+
# Mean forest only (no variance forest) is the failing configuration.
1600+
model.sample(X_train=X, y_train=y, num_gfr=0, num_burnin=0, num_mcmc=5)
1601+
1602+
preds = model.predict(X)
1603+
assert preds["y_hat"].shape[0] == n
1604+
15811605

15821606
class TestBARTFloat32:
15831607
"""Tests that float32 inputs are accepted and produce valid results (GH #389)."""

0 commit comments

Comments
 (0)