Skip to content

Commit 59912fb

Browse files
committed
test: cover stratified bootstrap and quantiles=None behavior
- test_predict_qte_preserves_per_stratum_counts: spy on _compute_qtes and assert every bootstrap replicate has the same per-stratum counts as the original sample. This would fail under a plain bootstrap. - test_predict_qte_default_quantiles: predict_qte without quantiles returns shape (9,) for the [0.1, ..., 0.9] default. - test_predict_qte_rejects_out_of_range_quantiles: values at the 0 / 1 boundary raise ValueError.
1 parent fa89916 commit 59912fb

1 file changed

Lines changed: 64 additions & 0 deletions

File tree

tests/test_stratified_estimators.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,67 @@ def test_different_alpha_values(self):
214214
width_010 = upper_010 - lower_010
215215

216216
self.assertTrue(np.all(width_010 < width_005))
217+
218+
def test_predict_qte_preserves_per_stratum_counts(self):
219+
# Stratified bootstrap must preserve per-stratum sample counts in every
220+
# bootstrap replicate. This would fail under a plain (unstratified) bootstrap.
221+
estimator = SimpleStratifiedDistributionEstimator()
222+
estimator.fit(self.X, self.W, self.Y, self.strata)
223+
224+
original_counts = np.bincount(self.strata.astype(int))
225+
226+
captured_strata = []
227+
original_compute = estimator._compute_qtes
228+
229+
def spy_compute(*args, **kwargs):
230+
captured_strata.append(args[-1])
231+
return original_compute(*args, **kwargs)
232+
233+
estimator._compute_qtes = spy_compute
234+
try:
235+
estimator.predict_qte(
236+
target_treatment_arm=1,
237+
control_treatment_arm=0,
238+
quantiles=np.array([0.5]),
239+
n_bootstrap=5,
240+
display_progress=False,
241+
)
242+
finally:
243+
estimator._compute_qtes = original_compute
244+
245+
# 1 call for the point estimate + 5 bootstrap calls
246+
self.assertEqual(len(captured_strata), 6)
247+
for strata in captured_strata:
248+
np.testing.assert_array_equal(
249+
np.bincount(strata.astype(int)), original_counts
250+
)
251+
252+
def test_predict_qte_default_quantiles(self):
253+
# quantiles=None should default to [0.1, 0.2, ..., 0.9] without erroring.
254+
estimator = SimpleStratifiedDistributionEstimator()
255+
estimator.fit(self.X, self.W, self.Y, self.strata)
256+
257+
qte, lower, upper = estimator.predict_qte(
258+
target_treatment_arm=1,
259+
control_treatment_arm=0,
260+
n_bootstrap=10,
261+
display_progress=False,
262+
)
263+
264+
self.assertEqual(qte.shape, (9,))
265+
self.assertEqual(lower.shape, (9,))
266+
self.assertEqual(upper.shape, (9,))
267+
self.assertTrue(np.all(lower <= upper))
268+
269+
def test_predict_qte_rejects_out_of_range_quantiles(self):
270+
estimator = SimpleStratifiedDistributionEstimator()
271+
estimator.fit(self.X, self.W, self.Y, self.strata)
272+
273+
with self.assertRaises(ValueError):
274+
estimator.predict_qte(
275+
1, 0, quantiles=np.array([0.0, 0.5]), n_bootstrap=5, display_progress=False
276+
)
277+
with self.assertRaises(ValueError):
278+
estimator.predict_qte(
279+
1, 0, quantiles=np.array([0.5, 1.0]), n_bootstrap=5, display_progress=False
280+
)

0 commit comments

Comments
 (0)