Skip to content

Commit d941702

Browse files
committed
test: improve unittest for simple estimator
1 parent d29b6fb commit d941702

2 files changed

Lines changed: 64 additions & 49 deletions

File tree

dte_adj/__init__.py

Lines changed: 20 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,7 @@ def find_quantile(quantile, arm):
311311

312312
return result
313313

314-
def predict(self, treatment_arms: np.ndarray, outcomes: np.ndarray) -> np.ndarray:
314+
def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
315315
"""Compute cumulative distribution values.
316316
317317
Args:
@@ -321,7 +321,25 @@ def predict(self, treatment_arms: np.ndarray, outcomes: np.ndarray) -> np.ndarra
321321
Returns:
322322
np.ndarray: Estimated cumulative distribution values for the input.
323323
"""
324-
raise NotImplementedError()
324+
if self.outcomes is None:
325+
raise ValueError(
326+
"This estimator has not been trained yet. Please call fit first"
327+
)
328+
329+
unincluded_arms = set(treatment_arms) - set(self.treatment_arms)
330+
331+
if len(unincluded_arms) > 0:
332+
raise ValueError(
333+
f"This treatment_arms argument contains arms not included in the training data: {unincluded_arms}"
334+
)
335+
336+
return self._compute_cumulative_distribution(
337+
treatment_arms,
338+
locations,
339+
self.confoundings,
340+
self.treatment_arms,
341+
self.outcomes,
342+
)[0]
325343

326344
def _compute_cumulative_distribution(
327345
self,
@@ -375,29 +393,6 @@ def fit(
375393

376394
return self
377395

378-
def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
379-
"""Compute cumulative distribution values.
380-
381-
Args:
382-
treatment_arms (np.ndarray): The index of the treatment arm.
383-
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
384-
385-
Returns:
386-
np.ndarray: Estimated cumulative distribution values for the input.
387-
"""
388-
if self.outcomes is None:
389-
raise ValueError(
390-
"This estimator has not been trained yet. Please call fit first"
391-
)
392-
393-
return self._compute_cumulative_distribution(
394-
treatment_arms,
395-
locations,
396-
self.confoundings,
397-
self.treatment_arms,
398-
self.outcomes,
399-
)[0]
400-
401396
def _compute_cumulative_distribution(
402397
self,
403398
target_treatment_arms: np.ndarray,

tests/test_simple_estimator.py

Lines changed: 44 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -5,51 +5,71 @@
55

66

77
class TestSimpleEstimator(unittest.TestCase):
8-
def test_prediction_success(self):
8+
def setUp(self):
9+
self.estimator = SimpleDistributionEstimator()
10+
self.confoundings = np.zeros((20, 5))
11+
self.treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
12+
self.outcomes = np.arange(20)
13+
self.estimator.fit(self.confoundings, self.treatment_arms, self.outcomes)
14+
15+
def test_fit(self):
16+
self.assertTrue(np.array_equal(self.estimator.confoundings, self.confoundings))
17+
self.assertTrue(
18+
np.array_equal(self.estimator.treatment_arms, self.treatment_arms)
19+
)
20+
self.assertTrue(np.array_equal(self.estimator.outcomes, self.outcomes))
21+
22+
def test_fit_invalid_shapes(self):
923
# Arrange
10-
X = np.arange(20)
11-
D = np.zeros(20)
12-
D[:10] = 1
13-
Y = np.arange(20)
14-
subject = SimpleDistributionEstimator()
15-
subject.fit(X, D, Y)
24+
confoundings_invalid = np.array([[1, 2], [3, 4]])
25+
treatment_arms_invalid = np.array([0, 1])
26+
outcomes_invalid = np.array([0.5, 0.7])
27+
28+
# Assert
29+
with self.assertRaises(ValueError):
30+
self.estimator.fit(confoundings_invalid, self.treatment_arms, self.outcomes)
31+
32+
with self.assertRaises(ValueError):
33+
self.estimator.fit(self.confoundings, treatment_arms_invalid, self.outcomes)
34+
35+
with self.assertRaises(ValueError):
36+
self.estimator.fit(self.confoundings, self.treatment_arms, outcomes_invalid)
37+
38+
def test_predict(self):
39+
# Arrange
40+
treatment_arms_test = np.array([0, 1])
41+
locations_test = np.array([3, 6])
42+
expected_output = np.array([0.4, 0])
1643

1744
# Act
18-
actual = subject.predict(D, Y)
45+
output = self.estimator.predict(treatment_arms_test, locations_test)
1946

2047
# Assert
21-
expected = np.array(
22-
[0.1 * i for i in range(1, 11)] + [0.1 * i for i in range(1, 11)]
23-
)
24-
npt.assert_allclose(actual, expected)
48+
np.testing.assert_array_almost_equal(output, expected_output, decimal=2)
2549

2650
def test_prediction_fail_before_fit(self):
2751
# Arrange
28-
D = np.zeros(20)
29-
D[:10] = 1
30-
Y = np.arange(20)
52+
treatment_arms_test = np.array([0, 1])
53+
locations_test = np.array([3, 6])
3154
subject = SimpleDistributionEstimator()
3255

3356
# Act, Assert
3457
with self.assertRaises(ValueError) as cm:
35-
subject.predict(D, Y)
58+
subject.predict(treatment_arms_test, locations_test)
3659
self.assertEqual(
3760
str(cm.exception),
3861
"This estimator has not been trained yet. Please call fit first",
3962
)
4063

41-
def test_fit_fail_invalid_input(self):
64+
def test_prediction_fail_invalid_arm(self):
4265
# Arrange
43-
X = np.arange(20)
44-
D = np.zeros(10)
45-
D[:10] = 1
46-
Y = np.arange(20)
47-
subject = SimpleDistributionEstimator()
66+
treatment_arms_invalid = np.array([2])
67+
locations_test = np.array([3, 6])
4868

4969
# Act, Assert
5070
with self.assertRaises(ValueError) as cm:
51-
subject.fit(X, D, Y)
71+
self.estimator.predict(treatment_arms_invalid, locations_test)
5272
self.assertEqual(
5373
str(cm.exception),
54-
"The shape of confounding and treatment_arm should be same",
74+
"This treatment_arms argument contains arms not included in the training data: {2}",
5575
)

0 commit comments

Comments
 (0)