Skip to content

Commit bceda6a

Browse files
committed
test: add unit tests for AdjustedDistributionEstimator
1 parent 101fc4c commit bceda6a

3 files changed

Lines changed: 69 additions & 16 deletions

File tree

dte_adj/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def find_quantile(quantile, arm):
311311
)
312312

313313
return result
314-
314+
315315
def fit(
316316
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
317317
) -> "DistributionEstimatorBase":
@@ -446,8 +446,12 @@ def __init__(self, base_model, folds=3):
446446
Returns:
447447
AdjustedDistributionEstimator: An instance of the estimator.
448448
"""
449-
if (not hasattr(base_model, 'predict')) and (not hasattr(base_model, 'predict_proba')):
450-
raise ValueError('base_model should implement either predict_proba or predict')
449+
if (not hasattr(base_model, "predict")) and (
450+
not hasattr(base_model, "predict_proba")
451+
):
452+
raise ValueError(
453+
"Base model should implement either predict_proba or predict"
454+
)
451455
self.base_model = base_model
452456
self.folds = folds
453457
super().__init__()
@@ -496,13 +500,19 @@ def _compute_cumulative_distribution(
496500
continue
497501
model = deepcopy(self.base_model)
498502
model.fit(confounding_train, binominal_train)
499-
subset_prediction[subset_mask] = model.predict_proba(confounding_fit)[
500-
:, 1
501-
]
502-
superset_prediction[superset_mask, i] = model.predict_proba(
503-
confoundings[superset_mask]
504-
)[:, 1]
503+
subset_prediction[subset_mask] = self._compute_model_prediction(
504+
model, confounding_fit
505+
)
506+
superset_prediction[superset_mask, i] = self._compute_model_prediction(
507+
model, confoundings[superset_mask]
508+
)
505509
cumulative_distribution[i] = (
506510
cdf - subset_prediction.mean() + superset_prediction[:, i].mean()
507511
)
508512
return cumulative_distribution, superset_prediction
513+
514+
def _compute_model_prediction(self, model, confoundings: np.ndarray) -> np.ndarray:
515+
if hasattr(model, "predict_proba"):
516+
return model.predict_proba(confoundings)[:, 1]
517+
else:
518+
return model.predict(confoundings)

tests/test_adjusted_estimator.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,25 @@
55

66

77
class TestAdjustedEstimator(unittest.TestCase):
8-
def test_prediction_success(self):
9-
# TODO!
10-
return
8+
def setUp(self):
9+
base_model = MagicMock()
10+
base_model.predict_proba.side_effect = lambda x, y: x
11+
self.estimator = AdjustedDistributionEstimator(base_model, folds=1)
12+
self.confoundings = np.zeros((20, 5))
13+
self.treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
14+
self.outcomes = np.arange(20)
15+
self.estimator.fit(self.confoundings, self.treatment_arms, self.outcomes)
16+
17+
def test_init_fail_incorrect_base_model(self):
18+
# Act, Assert
19+
with self.assertRaises(ValueError) as cm:
20+
AdjustedDistributionEstimator("dummy")
21+
self.assertEqual(
22+
str(cm.exception),
23+
"Base model should implement either predict_proba or predict",
24+
)
1125

12-
def test_prediction_fail_before_fit(self):
26+
def test_predict_fail_before_fit(self):
1327
# Arrange
1428
D = np.zeros(20)
1529
D[:10] = 1
@@ -41,3 +55,32 @@ def test_fit_fail_invalid_input(self):
4155
str(cm.exception),
4256
"The shape of confounding and treatment_arm should be same",
4357
)
58+
59+
def test_compute_cumulative_distribution(self):
60+
# Arrange
61+
mock_model = self.estimator.base_model
62+
mock_model.predict_proba.side_effect = lambda x: np.ones((x.shape[0], 2)) * 0.5
63+
target_treatment_arms = np.zeros(10)
64+
locations = np.arange(10)
65+
66+
# Act
67+
cumulative_distribution, superset_prediction = (
68+
self.estimator._compute_cumulative_distribution(
69+
target_treatment_arms,
70+
locations,
71+
self.confoundings,
72+
self.treatment_arms,
73+
self.outcomes,
74+
)
75+
)
76+
77+
# Assert
78+
self.assertEqual(cumulative_distribution.shape, (10,))
79+
self.assertEqual(superset_prediction.shape, (20, 10))
80+
81+
for i in range(10):
82+
self.assertAlmostEqual(cumulative_distribution[i], (i + 1) / 10, places=2)
83+
84+
for i in range(20):
85+
for j in range(10):
86+
self.assertAlmostEqual(superset_prediction[i, j], 0.5, places=2)

tests/test_distribution_estimator_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_fit_success(self):
133133
np.array_equal(self.estimator.treatment_arms, self.treatment_arms)
134134
)
135135
self.assertTrue(np.array_equal(self.estimator.outcomes, self.outcomes))
136-
136+
137137
def test_fit_invalid_shapes(self):
138138
# Arrange
139139
confoundings_invalid = np.array([[1, 2], [3, 4]])
@@ -149,7 +149,7 @@ def test_fit_invalid_shapes(self):
149149

150150
with self.assertRaises(ValueError):
151151
self.estimator.fit(self.confoundings, self.treatment_arms, outcomes_invalid)
152-
152+
153153
def test_predict_success(self):
154154
# Arrange
155155
treatment_arms_test = np.array([0, 1])
@@ -174,7 +174,7 @@ def test_predict_fail_before_fit(self):
174174
self.assertEqual(
175175
str(cm.exception),
176176
"This estimator has not been trained yet. Please call fit first",
177-
)
177+
)
178178

179179
def test_predict_fail_invalid_arm(self):
180180
# Arrange

0 commit comments

Comments
 (0)