Skip to content

Commit f8c0c1b

Browse files
committed
feat: change param names and add tests
1 parent 7e74d98 commit f8c0c1b

3 files changed

Lines changed: 131 additions & 17 deletions

File tree

dte_adj/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def predict_dte(
2424
control_treatment_arm: int,
2525
outcomes: np.ndarray,
2626
alpha: float = 0.05,
27-
variance_type="pointwise",
27+
variance_type="moment",
2828
n_bootstrap=500,
2929
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
3030
"""Compute DTE based on the estimator for the distribution function.
@@ -34,7 +34,7 @@ def predict_dte(
3434
control_treatment_arm (int): The index of the treatment arm of the control group.
3535
outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
3636
alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
37-
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are pointwise, analytic, and uniform.
37+
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, analytic, and uniform.
3838
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
3939
4040
Returns:
@@ -196,12 +196,18 @@ def _compute_expected_qtes(
196196
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
197197
"""Compute expected QTEs."""
198198
result = np.zeros(quantiles.shape)
199+
treatment_cumulative = self.predict(
200+
np.full(self.outcome.shape, target_treatment_arm), self.outcome
201+
)
202+
control_cumulative = self.predict(
203+
np.full(self.outcome.shape, control_treatment_arm), self.outcome
204+
)
199205
for i, q in enumerate(quantiles):
200-
treatment_quantile = self.outcome[target_treatment_arm][
201-
math.floor(self.outcome[i].shape[0] * q)
206+
treatment_quantile = treatment_cumulative[
207+
math.floor(treatment_cumulative.shape[0] * q)
202208
]
203-
control_quantile = self.outcome[control_treatment_arm][
204-
math.floor(self.outcome[i].shape[0] * q)
209+
control_quantile = control_cumulative[
210+
math.floor(control_cumulative.shape[0] * q)
205211
]
206212
result[i] = treatment_quantile - control_quantile
207213

@@ -481,7 +487,7 @@ def compute_dte_confidence_intervals(
481487
ind_target: int,
482488
ind_control: int,
483489
alpha: 0.05,
484-
variance_type="pointwise",
490+
variance_type="moment",
485491
n_bootstrap=500,
486492
):
487493
"""Computes the confidence intervals of DTE.
@@ -497,7 +503,7 @@ def compute_dte_confidence_intervals(
497503
ind_target (int): Index of the target treatment indicator.
498504
ind_control (int): Index of the control treatment indicator.
499505
alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
500-
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are pointwise, analytic, and uniform.
506+
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, analytic, and uniform.
501507
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
502508
503509
Returns:
@@ -530,7 +536,7 @@ def compute_dte_confidence_intervals(
530536

531537
omega = (influence_function**2).mean(axis=0)
532538

533-
if variance_type == "pointwise":
539+
if variance_type == "moment":
534540
vec_dte_lower_moment = vec_dte + norm.ppf(alpha / 2) * np.sqrt(omega / num_obs)
535541
vec_dte_upper_moment = vec_dte + norm.ppf(1 - alpha / 2) * np.sqrt(
536542
omega / num_obs

example/example.ipynb

Lines changed: 73 additions & 8 deletions
Large diffs are not rendered by default.

tests/test_adjusted_estimator.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
import numpy as np
3+
from dte_adj import AdjustedDistributionEstimator
4+
from unittest.mock import MagicMock
5+
6+
7+
class TestAdjustedEstimator(unittest.TestCase):
8+
def test_prediction_success(self):
9+
# TODO!
10+
return
11+
12+
def test_prediction_fail_before_fit(self):
13+
# Arrange
14+
D = np.zeros(20)
15+
D[:10] = 1
16+
Y = np.arange(20)
17+
base_model = MagicMock()
18+
subject = AdjustedDistributionEstimator(base_model)
19+
20+
# Act, Assert
21+
with self.assertRaises(RuntimeError) as cm:
22+
subject.predict(D, Y)
23+
self.assertEqual(
24+
str(cm.exception),
25+
"This estimator has not been trained yet. Please call fit first",
26+
)
27+
28+
def test_fit_fail_invalid_input(self):
29+
# Arrange
30+
X = np.arange(20)
31+
D = np.zeros(10)
32+
D[:10] = 1
33+
Y = np.arange(20)
34+
base_model = MagicMock()
35+
subject = AdjustedDistributionEstimator(base_model)
36+
37+
# Act, Assert
38+
with self.assertRaises(RuntimeError) as cm:
39+
subject.fit(X, D, Y)
40+
self.assertEqual(
41+
str(cm.exception),
42+
"The shape of confounding and treatment_arm should be same",
43+
)

0 commit comments

Comments
 (0)