Skip to content

Commit 101fc4c

Browse files
committed
test: add unittests for DistributionEstimatorBase
1 parent d941702 commit 101fc4c

4 files changed

Lines changed: 225 additions & 154 deletions

File tree

dte_adj/__init__.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import Tuple
33
from scipy.stats import norm
44
from copy import deepcopy
5-
from .util import compute_confidence_intervals, find_le
5+
from abc import ABC
6+
from .util import compute_confidence_intervals
67

78
__all__ = ["SimpleDistributionEstimator", "AdjustedDistributionEstimator"]
89

910

10-
class DistributionEstimatorBase(object):
11+
class DistributionEstimatorBase(ABC):
1112
"""A mixin including several convenience functions to compute and display distribution functions."""
1213

1314
def __init__(self):
@@ -310,6 +311,33 @@ def find_quantile(quantile, arm):
310311
)
311312

312313
return result
314+
315+
def fit(
316+
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
317+
) -> "DistributionEstimatorBase":
318+
"""Train the DistributionEstimatorBase.
319+
320+
Args:
321+
confoundings (np.ndarray): Pre-treatment covariates.
322+
treatment_arms (np.ndarray): The index of the treatment arm.
323+
outcomes (np.ndarray): Scalar-valued observed outcome.
324+
325+
Returns:
326+
DistributionEstimatorBase: The fitted estimator.
327+
"""
328+
if confoundings.shape[0] != treatment_arms.shape[0]:
329+
raise ValueError(
330+
"The shape of confounding and treatment_arm should be same"
331+
)
332+
333+
if confoundings.shape[0] != outcomes.shape[0]:
334+
raise ValueError("The shape of confounding and outcome should be same")
335+
336+
self.confoundings = confoundings
337+
self.treatment_arms = treatment_arms
338+
self.outcomes = outcomes
339+
340+
return self
313341

314342
def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
315343
"""Compute cumulative distribution values.
@@ -366,33 +394,6 @@ def __init__(self):
366394
"""
367395
super().__init__()
368396

369-
def fit(
370-
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
371-
) -> "SimpleDistributionEstimator":
372-
"""Train the SimpleDistributionEstimator.
373-
374-
Args:
375-
confoundings (np.ndarray): Pre-treatment covariates.
376-
treatment_arms (np.ndarray): The index of the treatment arm.
377-
outcomes (np.ndarray): Scalar-valued observed outcome.
378-
379-
Returns:
380-
SimpleDistributionEstimator: The fitted estimator.
381-
"""
382-
if confoundings.shape[0] != treatment_arms.shape[0]:
383-
raise ValueError(
384-
"The shape of confounding and treatment_arm should be same"
385-
)
386-
387-
if confoundings.shape[0] != outcomes.shape[0]:
388-
raise ValueError("The shape of confounding and outcome should be same")
389-
390-
self.confoundings = confoundings
391-
self.treatment_arms = treatment_arms
392-
self.outcomes = outcomes
393-
394-
return self
395-
396397
def _compute_cumulative_distribution(
397398
self,
398399
target_treatment_arms: np.ndarray,
@@ -427,7 +428,7 @@ def _compute_cumulative_distribution(
427428
cumulative_distribution = np.zeros(locations.shape)
428429
for i, (outcome, arm) in enumerate(zip(locations, target_treatment_arms)):
429430
cumulative_distribution[i] = (
430-
find_le(d_outcome[arm], outcome) + 1
431+
np.searchsorted(d_outcome[arm], outcome, side="right")
431432
) / d_outcome[arm].shape[0]
432433
return cumulative_distribution, np.zeros((n_obs, n_loc))
433434

@@ -445,60 +446,12 @@ def __init__(self, base_model, folds=3):
445446
Returns:
446447
AdjustedDistributionEstimator: An instance of the estimator.
447448
"""
449+
if (not hasattr(base_model, 'predict')) and (not hasattr(base_model, 'predict_proba')):
450+
raise ValueError('base_model should implement either predict_proba or predict')
448451
self.base_model = base_model
449452
self.folds = folds
450453
super().__init__()
451454

452-
def fit(
453-
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
454-
) -> "AdjustedDistributionEstimator":
455-
"""Train the AdjustedDistributionEstimator.
456-
457-
Args:
458-
confoundings (np.ndarray): Pre-treatment covariates.
459-
treatment_arms (np.ndarray): The index of the treatment arm.
460-
outcomes (np.ndarray): Scalar-valued observed outcome.
461-
462-
Returns:
463-
AdjustedDistributionEstimator: The fitted estimator.
464-
"""
465-
if confoundings.shape[0] != treatment_arms.shape[0]:
466-
raise ValueError(
467-
"The shape of confounding and treatment_arm should be same"
468-
)
469-
470-
if confoundings.shape[0] != outcomes.shape[0]:
471-
raise ValueError("The shape of confounding and outcome should be same")
472-
473-
self.confoundings = confoundings
474-
self.treatment_arms = treatment_arms
475-
self.outcomes = outcomes
476-
477-
return self
478-
479-
def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
480-
"""Compute cumulative distribution values.
481-
482-
Args:
483-
treatment_arms (np.ndarray): The index of the treatment arm.
484-
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
485-
486-
Returns:
487-
np.ndarray: Estimated cumulative distribution values for the input.
488-
"""
489-
if self.outcomes is None:
490-
raise ValueError(
491-
"This estimator has not been trained yet. Please call fit first"
492-
)
493-
494-
return self._compute_cumulative_distribution(
495-
treatment_arms,
496-
locations,
497-
self.confoundings,
498-
self.treatment_arms,
499-
self.outcomes,
500-
)[0]
501-
502455
def _compute_cumulative_distribution(
503456
self,
504457
target_treatment_arms: np.ndarray,

dte_adj/util.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy.stats import norm
3+
from typing import Tuple
34

45

56
def compute_confidence_intervals(
@@ -16,7 +17,7 @@ def compute_confidence_intervals(
1617
alpha: 0.05,
1718
variance_type="moment",
1819
n_bootstrap=500,
19-
):
20+
) -> Tuple[np.ndarray, np.ndarray]:
2021
"""Computes the confidence intervals of distribution parameters.
2122
2223
Args:
@@ -106,25 +107,3 @@ def compute_confidence_intervals(
106107
return vec_dte_lower_simple, vec_dte_upper_simple
107108
else:
108109
raise ValueError(f"Invalid variance type was speficied: {variance_type}")
109-
110-
111-
def find_le(array: np.ndarray, threshold):
112-
"""Find the rightmost value less than or equal to threshold in a sorted array
113-
114-
Args:
115-
array (np.ndarray): The sorted array to search in.
116-
threshold (float): The threshold value.
117-
118-
Returns:
119-
int: The index where the value first exceeds the threshold.
120-
"""
121-
low, high = 0, array.shape[0] - 1
122-
result = -1
123-
while low <= high:
124-
mid = (low + high) // 2
125-
if array[mid] <= threshold:
126-
result = mid
127-
low = mid + 1
128-
else:
129-
high = mid - 1
130-
return result
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import unittest
2+
import numpy as np
3+
from unittest.mock import patch, MagicMock
4+
from dte_adj import DistributionEstimatorBase
5+
6+
7+
def compute_cumulative_distribution(
8+
target_treatment_arms: np.ndarray,
9+
locations: np.ndarray,
10+
confoundings: np.ndarray,
11+
treatment_arms: np.ndarray,
12+
outcomes: np.array,
13+
) -> np.ndarray:
14+
"""Mock implementation for testing purposes."""
15+
return np.linspace(
16+
0, 0.9, locations.shape[0]
17+
) + target_treatment_arms * 0.1, np.zeros((outcomes.shape[0], locations.shape[0]))
18+
19+
20+
class MockDistributionEstimator(DistributionEstimatorBase):
21+
def __init__(
22+
self, mock_compute_cumulative_distribution=compute_cumulative_distribution
23+
):
24+
super().__init__()
25+
self.compute_cumulative_distribution = MagicMock()
26+
self.compute_cumulative_distribution.side_effect = (
27+
mock_compute_cumulative_distribution
28+
)
29+
30+
"""Mock class to implement _compute_cumulative_distribution for testing."""
31+
32+
def _compute_cumulative_distribution(
33+
self,
34+
target_treatment_arms: np.ndarray,
35+
locations: np.ndarray,
36+
confoundings: np.ndarray,
37+
treatment_arms: np.ndarray,
38+
outcomes: np.array,
39+
) -> np.ndarray:
40+
return self.compute_cumulative_distribution(
41+
target_treatment_arms, locations, confoundings, treatment_arms, outcomes
42+
)
43+
44+
45+
def compute_confidence_intervals(*args, **kwargs):
46+
"""Mock function for compute_confidence_intervals."""
47+
size = len(kwargs["vec_loc"])
48+
lower_bound = np.full(size, 0.1)
49+
upper_bound = np.full(size, 0.9)
50+
return lower_bound, upper_bound
51+
52+
53+
class TestDistributionEstimatorBase(unittest.TestCase):
54+
def setUp(self):
55+
self.estimator = MockDistributionEstimator()
56+
self.confoundings = np.zeros((20, 5))
57+
self.treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
58+
self.outcomes = np.arange(20)
59+
self.estimator.fit(self.confoundings, self.treatment_arms, self.outcomes)
60+
61+
def test_initialization(self):
62+
# Arrange
63+
base_estimator = MockDistributionEstimator()
64+
65+
# Assert
66+
self.assertIsNone(base_estimator.confoundings)
67+
self.assertIsNone(base_estimator.treatment_arms)
68+
self.assertIsNone(base_estimator.outcomes)
69+
70+
@patch(
71+
"dte_adj.compute_confidence_intervals", side_effect=compute_confidence_intervals
72+
)
73+
def test_predict_dte(self, mock_compute_confidence_intervals):
74+
# Arrange
75+
target_treatment_arm = 1
76+
control_treatment_arm = 0
77+
locations = np.arange(20)
78+
79+
# Act
80+
dte, lower_bound, upper_bound = self.estimator.predict_dte(
81+
target_treatment_arm, control_treatment_arm, locations
82+
)
83+
84+
# Assert
85+
np.testing.assert_array_almost_equal(dte, np.full(locations.shape, 0.1))
86+
np.testing.assert_array_almost_equal(lower_bound, np.full(locations.shape, 0.1))
87+
np.testing.assert_array_almost_equal(upper_bound, np.full(locations.shape, 0.9))
88+
self.estimator.compute_cumulative_distribution.assert_called()
89+
90+
@patch(
91+
"dte_adj.compute_confidence_intervals", side_effect=compute_confidence_intervals
92+
)
93+
def test_predict_pte(self, mock_compute_confidence_intervals):
94+
# Arrange
95+
target_treatment_arm = 1
96+
control_treatment_arm = 0
97+
locations = np.arange(20)
98+
width = 0.1
99+
100+
# Act
101+
pte, lower_bound, upper_bound = self.estimator.predict_pte(
102+
target_treatment_arm, control_treatment_arm, width, locations
103+
)
104+
105+
# Assert
106+
np.testing.assert_array_almost_equal(pte, np.full(locations.shape, 0))
107+
np.testing.assert_array_almost_equal(lower_bound, np.full(locations.shape, 0.1))
108+
np.testing.assert_array_almost_equal(upper_bound, np.full(locations.shape, 0.9))
109+
self.estimator.compute_cumulative_distribution.assert_called()
110+
111+
def test_predict_qte(self):
112+
# Arrange
113+
target_treatment_arm = 1
114+
control_treatment_arm = 0
115+
quantiles = np.array([0.1 * i for i in range(1, 10)])
116+
expected_qte = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
117+
118+
# Act
119+
qte, lower_bound, upper_bound = self.estimator.predict_qte(
120+
target_treatment_arm, control_treatment_arm, quantiles, n_bootstrap=5
121+
)
122+
123+
# Assert
124+
np.testing.assert_array_almost_equal(qte, expected_qte)
125+
np.testing.assert_array_almost_equal(lower_bound.shape, quantiles.shape)
126+
np.testing.assert_array_almost_equal(lower_bound.shape, quantiles.shape)
127+
self.estimator.compute_cumulative_distribution.assert_called()
128+
129+
def test_fit_success(self):
130+
# Assert
131+
self.assertTrue(np.array_equal(self.estimator.confoundings, self.confoundings))
132+
self.assertTrue(
133+
np.array_equal(self.estimator.treatment_arms, self.treatment_arms)
134+
)
135+
self.assertTrue(np.array_equal(self.estimator.outcomes, self.outcomes))
136+
137+
def test_fit_invalid_shapes(self):
138+
# Arrange
139+
confoundings_invalid = np.array([[1, 2], [3, 4]])
140+
treatment_arms_invalid = np.array([0, 1])
141+
outcomes_invalid = np.array([0.5, 0.7])
142+
143+
# Assert
144+
with self.assertRaises(ValueError):
145+
self.estimator.fit(confoundings_invalid, self.treatment_arms, self.outcomes)
146+
147+
with self.assertRaises(ValueError):
148+
self.estimator.fit(self.confoundings, treatment_arms_invalid, self.outcomes)
149+
150+
with self.assertRaises(ValueError):
151+
self.estimator.fit(self.confoundings, self.treatment_arms, outcomes_invalid)
152+
153+
def test_predict_success(self):
154+
# Arrange
155+
treatment_arms_test = np.array([0, 1])
156+
locations_test = np.array([3, 6])
157+
expected_output = np.array([0.4, 0])
158+
159+
# Act
160+
output = self.estimator.predict(treatment_arms_test, locations_test)
161+
162+
# Assert
163+
self.estimator.compute_cumulative_distribution.assert_called_once()
164+
165+
def test_predict_fail_before_fit(self):
166+
# Arrange
167+
treatment_arms_test = np.array([0, 1])
168+
locations_test = np.array([3, 6])
169+
subject = MockDistributionEstimator()
170+
171+
# Act, Assert
172+
with self.assertRaises(ValueError) as cm:
173+
subject.predict(treatment_arms_test, locations_test)
174+
self.assertEqual(
175+
str(cm.exception),
176+
"This estimator has not been trained yet. Please call fit first",
177+
)
178+
179+
def test_predict_fail_invalid_arm(self):
180+
# Arrange
181+
treatment_arms_invalid = np.array([2])
182+
locations_test = np.array([3, 6])
183+
184+
# Act, Assert
185+
with self.assertRaises(ValueError) as cm:
186+
self.estimator.predict(treatment_arms_invalid, locations_test)
187+
self.assertEqual(
188+
str(cm.exception),
189+
"This treatment_arms argument contains arms not included in the training data: {2}",
190+
)

0 commit comments

Comments
 (0)