@@ -1159,6 +1159,9 @@ def test_bcf_float32_with_propensity(self):
11591159 assert bcf_model .y_hat_test .shape == (self .n_test , self .num_mcmc )
11601160 assert bcf_model .tau_hat_train .shape == (self .n_train , self .num_mcmc )
11611161 assert bcf_model .tau_hat_test .shape == (self .n_test , self .num_mcmc )
1162+ preds = bcf_model .predict (X = self .X_test , Z = self .Z_test , propensity = self .pi_test )
1163+ assert preds ["y_hat" ].shape == (self .n_test , self .num_mcmc )
1164+ assert preds ["tau_hat" ].shape == (self .n_test , self .num_mcmc )
11621165
11631166 def test_bcf_float32_with_propensity_matches_float64 (self ):
11641167 common = dict (num_gfr = 5 , num_burnin = 0 , num_mcmc = self .num_mcmc , general_params = {"random_seed" : 1 })
@@ -1175,6 +1178,10 @@ def test_bcf_float32_with_propensity_matches_float64(self):
11751178 Z_test = self .Z_test .astype (np .float64 ),
11761179 propensity_test = self .pi_test .astype (np .float64 ), ** common )
11771180 np .testing .assert_allclose (bcf32 .y_hat_train , bcf64 .y_hat_train , rtol = 1e-4 )
1181+ pred32 = bcf32 .predict (X = self .X_test , Z = self .Z_test , propensity = self .pi_test )
1182+ pred64 = bcf32 .predict (X = self .X_test .astype (np .float64 ), Z = self .Z_test .astype (np .float64 ),
1183+ propensity = self .pi_test .astype (np .float64 ))
1184+ np .testing .assert_allclose (pred32 ["y_hat" ], pred64 ["y_hat" ], rtol = 1e-4 )
11781185
11791186 def test_bcf_float32_no_propensity (self ):
11801187 """float32 Z, y, X with internal propensity estimation."""
@@ -1191,6 +1198,8 @@ def test_bcf_float32_no_propensity(self):
11911198 )
11921199 assert bcf_model .y_hat_train .shape == (self .n_train , self .num_mcmc )
11931200 assert bcf_model .y_hat_test .shape == (self .n_test , self .num_mcmc )
1201+ preds = bcf_model .predict (X = self .X_test , Z = self .Z_test )
1202+ assert preds ["y_hat" ].shape == (self .n_test , self .num_mcmc )
11941203
11951204 def test_bcf_float32_no_propensity_matches_float64 (self ):
11961205 common = dict (num_gfr = 5 , num_burnin = 0 , num_mcmc = self .num_mcmc , general_params = {"random_seed" : 1 })
@@ -1204,6 +1213,9 @@ def test_bcf_float32_no_propensity_matches_float64(self):
12041213 X_test = self .X_test .astype (np .float64 ),
12051214 Z_test = self .Z_test .astype (np .float64 ), ** common )
12061215 np .testing .assert_allclose (bcf32 .y_hat_train , bcf64 .y_hat_train , rtol = 1e-4 )
1216+ pred32 = bcf32 .predict (X = self .X_test , Z = self .Z_test )
1217+ pred64 = bcf32 .predict (X = self .X_test .astype (np .float64 ), Z = self .Z_test .astype (np .float64 ))
1218+ np .testing .assert_allclose (pred32 ["y_hat" ], pred64 ["y_hat" ], rtol = 1e-4 )
12071219
12081220 def test_bcf_float32_rfx (self ):
12091221 """float32 rfx_basis_train and rfx_basis_test."""
@@ -1232,6 +1244,9 @@ def test_bcf_float32_rfx(self):
12321244 )
12331245 assert bcf_model .y_hat_train .shape == (self .n_train , self .num_mcmc )
12341246 assert bcf_model .y_hat_test .shape == (self .n_test , self .num_mcmc )
1247+ preds = bcf_model .predict (X = self .X_test , Z = self .Z_test , propensity = self .pi_test ,
1248+ rfx_group_ids = group_ids_test , rfx_basis = rfx_basis_test )
1249+ assert preds ["y_hat" ].shape == (self .n_test , self .num_mcmc )
12351250
12361251 def test_bcf_float32_rfx_matches_float64 (self ):
12371252 rng = np .random .default_rng (7 )
@@ -1256,3 +1271,9 @@ def test_bcf_float32_rfx_matches_float64(self):
12561271 rfx_basis_train = rfx_basis_train .astype (np .float64 ),
12571272 rfx_basis_test = rfx_basis_test .astype (np .float64 ), ** common )
12581273 np .testing .assert_allclose (bcf32 .y_hat_train , bcf64 .y_hat_train , rtol = 1e-4 )
1274+ pred32 = bcf32 .predict (X = self .X_test , Z = self .Z_test , propensity = self .pi_test ,
1275+ rfx_group_ids = group_ids_test , rfx_basis = rfx_basis_test )
1276+ pred64 = bcf32 .predict (X = self .X_test .astype (np .float64 ), Z = self .Z_test .astype (np .float64 ),
1277+ propensity = self .pi_test .astype (np .float64 ),
1278+ rfx_group_ids = group_ids_test , rfx_basis = rfx_basis_test .astype (np .float64 ))
1279+ np .testing .assert_allclose (pred32 ["y_hat" ], pred64 ["y_hat" ], rtol = 1e-4 )
0 commit comments