Skip to content

Commit 90491b9

Browse files
committed
test: fix unittests
1 parent cbcbb6e commit 90491b9

5 files changed

Lines changed: 39 additions & 16 deletions

File tree

dte_adj/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -355,12 +355,12 @@ def fit(
355355
SimpleDistributionEstimator: The fitted estimator.
356356
"""
357357
if confoundings.shape[0] != treatment_arms.shape[0]:
358-
raise RuntimeError(
358+
raise ValueError(
359359
"The shape of confounding and treatment_arm should be same"
360360
)
361361

362362
if confoundings.shape[0] != outcomes.shape[0]:
363-
raise RuntimeError("The shape of confounding and outcome should be same")
363+
raise ValueError("The shape of confounding and outcome should be same")
364364

365365
self.confoundings = confoundings
366366
self.treatment_arms = treatment_arms
@@ -379,7 +379,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
379379
np.ndarray: Estimated cumulative distribution values for the input.
380380
"""
381381
if self.outcomes is None:
382-
raise RuntimeError(
382+
raise ValueError(
383383
"This estimator has not been trained yet. Please call fit first"
384384
)
385385

@@ -461,12 +461,12 @@ def fit(
461461
AdjustedDistributionEstimator: The fitted estimator.
462462
"""
463463
if confoundings.shape[0] != treatment_arms.shape[0]:
464-
raise RuntimeError(
464+
raise ValueError(
465465
"The shape of confounding and treatment_arm should be same"
466466
)
467467

468468
if confoundings.shape[0] != outcomes.shape[0]:
469-
raise RuntimeError("The shape of confounding and outcome should be same")
469+
raise ValueError("The shape of confounding and outcome should be same")
470470

471471
self.confoundings = confoundings
472472
self.treatment_arms = treatment_arms
@@ -485,7 +485,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
485485
np.ndarray: Estimated cumulative distribution values for the input.
486486
"""
487487
if self.outcomes is None:
488-
raise RuntimeError(
488+
raise ValueError(
489489
"This estimator has not been trained yet. Please call fit first"
490490
)
491491

dte_adj/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def compute_confidence_intervals(
105105

106106
return vec_dte_lower_simple, vec_dte_upper_simple
107107
else:
108-
raise RuntimeError(f"Invalid variance type was speficied: {variance_type}")
108+
raise ValueError(f"Invalid variance type was speficied: {variance_type}")
109109

110110

111111
def find_le(array: np.ndarray, threshold):

tests/test_adjusted_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_prediction_fail_before_fit(self):
1818
subject = AdjustedDistributionEstimator(base_model)
1919

2020
# Act, Assert
21-
with self.assertRaises(RuntimeError) as cm:
21+
with self.assertRaises(ValueError) as cm:
2222
subject.predict(D, Y)
2323
self.assertEqual(
2424
str(cm.exception),
@@ -35,7 +35,7 @@ def test_fit_fail_invalid_input(self):
3535
subject = AdjustedDistributionEstimator(base_model)
3636

3737
# Act, Assert
38-
with self.assertRaises(RuntimeError) as cm:
38+
with self.assertRaises(ValueError) as cm:
3939
subject.fit(X, D, Y)
4040
self.assertEqual(
4141
str(cm.exception),

tests/test_plot.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class TestPlot(unittest.TestCase):
99
def test_plot(self, mock_plt):
1010
# Arrange
1111
x_values = np.array([1, 2, 3, 4, 5])
12-
y_values = np.array([1, 2, 3, 4, 5])
12+
means = np.array([1, 2, 3, 4, 5])
1313
upper_bands = np.array([2, 3, 4, 5, 6])
1414
lower_bands = np.array([0, 1, 2, 3, 4])
1515
mock_ax = MagicMock()
@@ -18,12 +18,13 @@ def test_plot(self, mock_plt):
1818
# Act
1919
result_ax = plot(
2020
x_values,
21-
y_values,
22-
upper_bands,
21+
means,
2322
lower_bands,
23+
upper_bands,
2424
title="Test Title",
2525
xlabel="X Axis",
2626
ylabel="Y Axis",
27+
chart_type="line",
2728
)
2829

2930
# Assert
@@ -34,7 +35,7 @@ def test_plot(self, mock_plt):
3435
plot_args, plot_kwargs = plot_call
3536
x_values_arg, y_values_arg = plot_args
3637
self.assertTrue(np.array_equal(x_values_arg, x_values))
37-
self.assertTrue(np.array_equal(y_values_arg, y_values))
38+
self.assertTrue(np.array_equal(y_values_arg, means))
3839
fill_between_args, fill_between_kwargs = fill_between_call
3940
x_fill, lower_fill, upper_fill = fill_between_args
4041
self.assertTrue(np.array_equal(x_fill, x_values_arg))
@@ -46,8 +47,30 @@ def test_plot(self, mock_plt):
4647
mock_ax.set_title.assert_called_once_with("Test Title")
4748
mock_ax.set_xlabel.assert_called_once_with("X Axis")
4849
mock_ax.set_ylabel.assert_called_once_with("Y Axis")
49-
mock_ax.legend.assert_called_once()
5050

51+
def test_plot_fail_unknown_chart_type(self):
52+
# Arrange
53+
x_values = np.array([1, 2, 3, 4, 5])
54+
means = np.array([1, 2, 3, 4, 5])
55+
upper_bands = np.array([2, 3, 4, 5, 6])
56+
lower_bands = np.array([0, 1, 2, 3, 4])
57+
58+
# Act, Assert
59+
with self.assertRaises(ValueError) as cm:
60+
plot(
61+
x_values,
62+
means,
63+
lower_bands,
64+
upper_bands,
65+
title="Test Title",
66+
xlabel="X Axis",
67+
ylabel="Y Axis",
68+
chart_type="other",
69+
)
70+
self.assertEqual(
71+
str(cm.exception),
72+
"Chart type other is not supported",
73+
)
5174

5275
if __name__ == "__main__":
5376
unittest.main()

tests/test_simple_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_prediction_fail_before_fit(self):
3131
subject = SimpleDistributionEstimator()
3232

3333
# Act, Assert
34-
with self.assertRaises(RuntimeError) as cm:
34+
with self.assertRaises(ValueError) as cm:
3535
subject.predict(D, Y)
3636
self.assertEqual(
3737
str(cm.exception),
@@ -47,7 +47,7 @@ def test_fit_fail_invalid_input(self):
4747
subject = SimpleDistributionEstimator()
4848

4949
# Act, Assert
50-
with self.assertRaises(RuntimeError) as cm:
50+
with self.assertRaises(ValueError) as cm:
5151
subject.fit(X, D, Y)
5252
self.assertEqual(
5353
str(cm.exception),

0 commit comments

Comments
 (0)