Skip to content

Commit 7fb469c

Browse files
authored
Merge pull request #391 from StochasticTree/worktree-float32-fix-backport-0.4.3
Add predict() coverage to float32 test suite
2 parents 181ec82 + 8b0cd70 commit 7fb469c

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

test/python/test_bart.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,8 @@ def test_bart_float32_runs(self):
16041604
)
16051605
assert bart_model.y_hat_train.shape == (self.n_train, self.num_mcmc)
16061606
assert bart_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
1607+
preds = bart_model.predict(X=self.X_test)
1608+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
16071609

16081610
def test_bart_float32_matches_float64(self):
16091611
"""float32 and float64 inputs with the same seed should produce close results."""
@@ -1628,6 +1630,9 @@ def test_bart_float32_matches_float64(self):
16281630
general_params={"random_seed": 1},
16291631
)
16301632
np.testing.assert_allclose(bart32.y_hat_train, bart64.y_hat_train, rtol=1e-5)
1633+
pred32 = bart32.predict(X=self.X_test)
1634+
pred64 = bart32.predict(X=self.X_test.astype(np.float64))
1635+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-5)
16311636

16321637
def test_bart_float32_leaf_basis(self):
16331638
rng = np.random.default_rng(7)
@@ -1646,6 +1651,8 @@ def test_bart_float32_leaf_basis(self):
16461651
)
16471652
assert bart_model.y_hat_train.shape == (self.n_train, self.num_mcmc)
16481653
assert bart_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
1654+
preds = bart_model.predict(X=self.X_test, leaf_basis=basis_test)
1655+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
16491656

16501657
def test_bart_float32_leaf_basis_matches_float64(self):
16511658
rng = np.random.default_rng(7)
@@ -1663,6 +1670,9 @@ def test_bart_float32_leaf_basis_matches_float64(self):
16631670
X_test=self.X_test.astype(np.float64),
16641671
leaf_basis_test=basis_test.astype(np.float64), **common)
16651672
np.testing.assert_allclose(bart32.y_hat_train, bart64.y_hat_train, rtol=1e-5)
1673+
pred32 = bart32.predict(X=self.X_test, leaf_basis=basis_test)
1674+
pred64 = bart32.predict(X=self.X_test.astype(np.float64), leaf_basis=basis_test.astype(np.float64))
1675+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-5)
16661676

16671677
def test_bart_float32_rfx(self):
16681678
rng = np.random.default_rng(7)
@@ -1686,6 +1696,8 @@ def test_bart_float32_rfx(self):
16861696
)
16871697
assert bart_model.y_hat_train.shape == (self.n_train, self.num_mcmc)
16881698
assert bart_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
1699+
preds = bart_model.predict(X=self.X_test, rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test)
1700+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
16891701

16901702
def test_bart_float32_rfx_matches_float64(self):
16911703
rng = np.random.default_rng(7)
@@ -1706,3 +1718,6 @@ def test_bart_float32_rfx_matches_float64(self):
17061718
rfx_basis_train=rfx_basis_train.astype(np.float64),
17071719
rfx_basis_test=rfx_basis_test.astype(np.float64), **common)
17081720
np.testing.assert_allclose(bart32.y_hat_train, bart64.y_hat_train, rtol=1e-4)
1721+
pred32 = bart32.predict(X=self.X_test, rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test)
1722+
pred64 = bart32.predict(X=self.X_test.astype(np.float64), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test.astype(np.float64))
1723+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-4)

test/python/test_bcf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,9 @@ def test_bcf_float32_with_propensity(self):
11591159
assert bcf_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
11601160
assert bcf_model.tau_hat_train.shape == (self.n_train, self.num_mcmc)
11611161
assert bcf_model.tau_hat_test.shape == (self.n_test, self.num_mcmc)
1162+
preds = bcf_model.predict(X=self.X_test, Z=self.Z_test, propensity=self.pi_test)
1163+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
1164+
assert preds["tau_hat"].shape == (self.n_test, self.num_mcmc)
11621165

11631166
def test_bcf_float32_with_propensity_matches_float64(self):
11641167
common = dict(num_gfr=5, num_burnin=0, num_mcmc=self.num_mcmc, general_params={"random_seed": 1})
@@ -1175,6 +1178,10 @@ def test_bcf_float32_with_propensity_matches_float64(self):
11751178
Z_test=self.Z_test.astype(np.float64),
11761179
propensity_test=self.pi_test.astype(np.float64), **common)
11771180
np.testing.assert_allclose(bcf32.y_hat_train, bcf64.y_hat_train, rtol=1e-4)
1181+
pred32 = bcf32.predict(X=self.X_test, Z=self.Z_test, propensity=self.pi_test)
1182+
pred64 = bcf32.predict(X=self.X_test.astype(np.float64), Z=self.Z_test.astype(np.float64),
1183+
propensity=self.pi_test.astype(np.float64))
1184+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-4)
11781185

11791186
def test_bcf_float32_no_propensity(self):
11801187
"""float32 Z, y, X with internal propensity estimation."""
@@ -1191,6 +1198,8 @@ def test_bcf_float32_no_propensity(self):
11911198
)
11921199
assert bcf_model.y_hat_train.shape == (self.n_train, self.num_mcmc)
11931200
assert bcf_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
1201+
preds = bcf_model.predict(X=self.X_test, Z=self.Z_test)
1202+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
11941203

11951204
def test_bcf_float32_no_propensity_matches_float64(self):
11961205
common = dict(num_gfr=5, num_burnin=0, num_mcmc=self.num_mcmc, general_params={"random_seed": 1})
@@ -1204,6 +1213,9 @@ def test_bcf_float32_no_propensity_matches_float64(self):
12041213
X_test=self.X_test.astype(np.float64),
12051214
Z_test=self.Z_test.astype(np.float64), **common)
12061215
np.testing.assert_allclose(bcf32.y_hat_train, bcf64.y_hat_train, rtol=1e-4)
1216+
pred32 = bcf32.predict(X=self.X_test, Z=self.Z_test)
1217+
pred64 = bcf32.predict(X=self.X_test.astype(np.float64), Z=self.Z_test.astype(np.float64))
1218+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-4)
12071219

12081220
def test_bcf_float32_rfx(self):
12091221
"""float32 rfx_basis_train and rfx_basis_test."""
@@ -1232,6 +1244,9 @@ def test_bcf_float32_rfx(self):
12321244
)
12331245
assert bcf_model.y_hat_train.shape == (self.n_train, self.num_mcmc)
12341246
assert bcf_model.y_hat_test.shape == (self.n_test, self.num_mcmc)
1247+
preds = bcf_model.predict(X=self.X_test, Z=self.Z_test, propensity=self.pi_test,
1248+
rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test)
1249+
assert preds["y_hat"].shape == (self.n_test, self.num_mcmc)
12351250

12361251
def test_bcf_float32_rfx_matches_float64(self):
12371252
rng = np.random.default_rng(7)
@@ -1256,3 +1271,9 @@ def test_bcf_float32_rfx_matches_float64(self):
12561271
rfx_basis_train=rfx_basis_train.astype(np.float64),
12571272
rfx_basis_test=rfx_basis_test.astype(np.float64), **common)
12581273
np.testing.assert_allclose(bcf32.y_hat_train, bcf64.y_hat_train, rtol=1e-4)
1274+
pred32 = bcf32.predict(X=self.X_test, Z=self.Z_test, propensity=self.pi_test,
1275+
rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test)
1276+
pred64 = bcf32.predict(X=self.X_test.astype(np.float64), Z=self.Z_test.astype(np.float64),
1277+
propensity=self.pi_test.astype(np.float64),
1278+
rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test.astype(np.float64))
1279+
np.testing.assert_allclose(pred32["y_hat"], pred64["y_hat"], rtol=1e-4)

0 commit comments

Comments
 (0)