Skip to content

Commit c43869d

Browse files
committed
add e2e for simple estimators
1 parent 49b6a4b commit c43869d

2 files changed

Lines changed: 191 additions & 116 deletions

File tree

tests/test_adjusted_estimator.py

Lines changed: 0 additions & 115 deletions
This file was deleted.

tests/test_simple_estimator.py

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,47 @@
11
import unittest
22
import numpy as np
3-
from dte_adj import SimpleDistributionEstimator
3+
from unittest.mock import patch, MagicMock
4+
from sklearn.linear_model import LogisticRegression
5+
from dte_adj import SimpleDistributionEstimator, AdjustedDistributionEstimator
6+
7+
8+
def generate_data(n, d_x=100, rho=0.5):
9+
"""
10+
Generate data according to the described data generating process (DGP).
11+
12+
Args:
13+
n (int): Number of samples.
14+
d_x (int): Number of covariates. Default is 100.
15+
rho (float): Success probability for the Bernoulli distribution. Default is 0.5.
16+
17+
Returns:
18+
X (np.ndarray): Covariates matrix of shape (n, d_x).
19+
D (np.ndarray): Treatment variable array of shape (n,).
20+
Y (np.ndarray): Outcome variable array of shape (n,).
21+
"""
22+
# Generate covariates X from a uniform distribution on (0, 1)
23+
X = np.random.uniform(0, 1, (n, d_x))
24+
25+
# Generate treatment variable D from a Bernoulli distribution with success probability rho
26+
D = np.random.binomial(1, rho, n)
27+
28+
# Define beta_j and gamma_j according to the problem statement
29+
beta = np.zeros(d_x)
30+
gamma = np.zeros(d_x)
31+
32+
# Set the first 50 values of beta and gamma to 1
33+
beta[:50] = 1
34+
gamma[:50] = 1
35+
36+
# Compute the outcome Y
37+
U = np.random.normal(0, 1, n) # Error term
38+
linear_term = np.dot(X, beta)
39+
quadratic_term = np.dot(X**2, gamma)
40+
41+
# Outcome equation
42+
Y = 5 * D + linear_term + quadratic_term + U
43+
44+
return X, D, Y
445

546

647
class TestSimpleEstimator(unittest.TestCase):
@@ -38,3 +79,152 @@ def test_fit_invalid_shapes(self):
3879

3980
with self.assertRaises(ValueError):
4081
self.estimator.fit(self.covariates, self.treatment_arms, outcomes_invalid)
82+
83+
84+
class TestAdjustedEstimator(unittest.TestCase):
85+
def setUp(self):
86+
base_model = MagicMock()
87+
base_model.predict_proba.side_effect = lambda x, y: x
88+
self.estimator = AdjustedDistributionEstimator(base_model, folds=2)
89+
self.covariates = np.zeros((20, 5))
90+
self.treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
91+
self.outcomes = np.arange(20)
92+
self.estimator.fit(self.covariates, self.treatment_arms, self.outcomes)
93+
94+
def test_init_fail_incorrect_base_model(self):
95+
# Act, Assert
96+
with self.assertRaises(ValueError) as cm:
97+
AdjustedDistributionEstimator("dummy")
98+
self.assertEqual(
99+
str(cm.exception),
100+
"Base model should implement either predict_proba or predict",
101+
)
102+
103+
def test_predict_fail_before_fit(self):
104+
# Arrange
105+
D = np.zeros(20)
106+
D[:10] = 1
107+
Y = np.arange(20)
108+
base_model = MagicMock()
109+
subject = AdjustedDistributionEstimator(base_model)
110+
111+
# Act, Assert
112+
with self.assertRaises(ValueError) as cm:
113+
subject.predict(D, Y)
114+
self.assertEqual(
115+
str(cm.exception),
116+
"This estimator has not been trained yet. Please call fit first",
117+
)
118+
119+
def test_fit_fail_invalid_input(self):
120+
# Arrange
121+
X = np.arange(20)
122+
D = np.zeros(10)
123+
D[:10] = 1
124+
Y = np.arange(20)
125+
base_model = MagicMock()
126+
subject = AdjustedDistributionEstimator(base_model)
127+
128+
# Act, Assert
129+
with self.assertRaises(ValueError) as cm:
130+
subject.fit(X, D, Y)
131+
self.assertEqual(
132+
str(cm.exception),
133+
"The shape of covariates and treatment_arm should be same",
134+
)
135+
136+
def test_compute_cumulative_distribution(self):
137+
# Arrange
138+
mock_model = self.estimator.base_model
139+
mock_model.predict_proba.side_effect = lambda x: np.ones((len(x), 2)) * 0.5
140+
target_treatment_arm = 0
141+
locations = np.arange(10)
142+
143+
# Act
144+
with patch(
145+
"numpy.random.randint",
146+
return_value=np.array([0] * 5 + [1] * 5 + [0] * 5 + [1] * 5),
147+
):
148+
cumulative_distribution, _, superset_prediction = (
149+
self.estimator._compute_cumulative_distribution(
150+
target_treatment_arm,
151+
locations,
152+
self.covariates,
153+
self.treatment_arms,
154+
self.outcomes,
155+
)
156+
)
157+
158+
# Assert
159+
self.assertEqual(cumulative_distribution.shape, (10,))
160+
self.assertEqual(superset_prediction.shape, (20, 10))
161+
162+
for i in range(10):
163+
self.assertAlmostEqual(cumulative_distribution[i], (i + 1) / 10, places=2)
164+
165+
expected_result = np.array(
166+
[
167+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
168+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
169+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
170+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
171+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
172+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
173+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
174+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
175+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
176+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
177+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
178+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
179+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
180+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
181+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, 1.0],
182+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
183+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
184+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
185+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
186+
[0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
187+
]
188+
)
189+
np.testing.assert_array_almost_equal(
190+
superset_prediction, expected_result, decimal=2
191+
)
192+
193+
194+
class TestE2E(unittest.TestCase):
195+
def test_e2e(self):
196+
# Arrange
197+
X, D, Y = generate_data(n=1000)
198+
locations = np.array([np.percentile(Y, p) for p in range(10, 91, 10)])
199+
simple_estimator = SimpleDistributionEstimator()
200+
adjusted_estimator = AdjustedDistributionEstimator(LogisticRegression())
201+
202+
# Act
203+
simple_estimator.fit(X, D, Y)
204+
adjusted_estimator.fit(X, D, Y)
205+
206+
simple_dte, simple_lower_bound, simple_upper_bound = (
207+
simple_estimator.predict_dte(1, 0, locations)
208+
)
209+
adjusted_dte, adjusted_lower_bound, adjusted_upper_bound = (
210+
adjusted_estimator.predict_dte(1, 0, locations)
211+
)
212+
213+
# Assert
214+
np.testing.assert_(np.all(simple_dte < 0), "Not all values are negative")
215+
np.testing.assert_(np.all(adjusted_dte < 0), "Not all values are negative")
216+
np.testing.assert_(
217+
np.all(simple_lower_bound < simple_upper_bound),
218+
"Upper bound is less than lower bound",
219+
)
220+
np.testing.assert_(
221+
np.all(adjusted_lower_bound < adjusted_upper_bound),
222+
"Upper bound is less than lower bound",
223+
)
224+
np.testing.assert_(
225+
np.all(
226+
adjusted_upper_bound - adjusted_lower_bound
227+
< simple_upper_bound - simple_lower_bound
228+
),
229+
"Adjusted estimator does not have narrower intervals",
230+
)

0 commit comments

Comments
 (0)