Skip to content

Commit ff82f0c

Browse files
authored
Merge pull request #14 from CyberAgentAILab/chore/unittest
Improve unit tests
2 parents 8ea9b93 + 2c1f9d7 commit ff82f0c

5 files changed

Lines changed: 310 additions & 172 deletions

File tree

dte_adj/__init__.py

Lines changed: 59 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import Tuple
33
from scipy.stats import norm
44
from copy import deepcopy
5-
from .util import compute_confidence_intervals, find_le
5+
from abc import ABC
6+
from .util import compute_confidence_intervals
67

78
__all__ = ["SimpleDistributionEstimator", "AdjustedDistributionEstimator"]
89

910

10-
class DistributionFunctionMixin(object):
11+
class DistributionEstimatorBase(ABC):
1112
"""A base class providing several convenience functions to compute and display distribution functions."""
1213

1314
def __init__(self):
@@ -311,55 +312,18 @@ def find_quantile(quantile, arm):
311312

312313
return result
313314

314-
def predict(self, treatment_arms: np.ndarray, outcomes: np.ndarray) -> np.ndarray:
315-
"""Compute cumulative distribution values.
316-
317-
Args:
318-
treatment_arms (np.ndarray): The index of the treatment arm.
319-
outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
320-
321-
Returns:
322-
np.ndarray: Estimated cumulative distribution values for the input.
323-
"""
324-
raise NotImplementedError()
325-
326-
def _compute_cumulative_distribution(
327-
self,
328-
target_treatment_arms: np.ndarray,
329-
locations: np.ndarray,
330-
confoundings: np.ndarray,
331-
treatment_arms: np.ndarray,
332-
outcomes: np.array,
333-
) -> np.ndarray:
334-
"""Compute the cumulative distribution values."""
335-
raise NotImplementedError()
336-
337-
338-
class SimpleDistributionEstimator(DistributionFunctionMixin):
339-
"""A class for computing the empirical distribution function and the distributional parameters
340-
based on the distribution function.
341-
"""
342-
343-
def __init__(self):
344-
"""Initializes the SimpleDistributionEstimator.
345-
346-
Returns:
347-
SimpleDistributionEstimator: An instance of the estimator.
348-
"""
349-
super().__init__()
350-
351315
def fit(
352316
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
353-
) -> "SimpleDistributionEstimator":
354-
"""Train the SimpleDistributionEstimator.
317+
) -> "DistributionEstimatorBase":
318+
"""Train the DistributionEstimatorBase.
355319
356320
Args:
357321
confoundings (np.ndarray): Pre-treatment covariates.
358322
treatment_arms (np.ndarray): The index of the treatment arm.
359323
outcomes (np.ndarray): Scalar-valued observed outcome.
360324
361325
Returns:
362-
SimpleDistributionEstimator: The fitted estimator.
326+
DistributionEstimatorBase: The fitted estimator.
363327
"""
364328
if confoundings.shape[0] != treatment_arms.shape[0]:
365329
raise ValueError(
@@ -380,7 +344,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
380344
381345
Args:
382346
treatment_arms (np.ndarray): The index of the treatment arm.
383-
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
347+
outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
384348
385349
Returns:
386350
np.ndarray: Estimated cumulative distribution values for the input.
@@ -390,6 +354,13 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
390354
"This estimator has not been trained yet. Please call fit first"
391355
)
392356

357+
unincluded_arms = set(treatment_arms) - set(self.treatment_arms)
358+
359+
if len(unincluded_arms) > 0:
360+
raise ValueError(
361+
f"This treatment_arms argument contains arms not included in the training data: {unincluded_arms}"
362+
)
363+
393364
return self._compute_cumulative_distribution(
394365
treatment_arms,
395366
locations,
@@ -398,6 +369,31 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
398369
self.outcomes,
399370
)[0]
400371

372+
def _compute_cumulative_distribution(
373+
self,
374+
target_treatment_arms: np.ndarray,
375+
locations: np.ndarray,
376+
confoundings: np.ndarray,
377+
treatment_arms: np.ndarray,
378+
outcomes: np.array,
379+
) -> np.ndarray:
380+
"""Compute the cumulative distribution values."""
381+
raise NotImplementedError()
382+
383+
384+
class SimpleDistributionEstimator(DistributionEstimatorBase):
385+
"""A class for computing the empirical distribution function and the distributional parameters
386+
based on the distribution function.
387+
"""
388+
389+
def __init__(self):
390+
"""Initializes the SimpleDistributionEstimator.
391+
392+
Returns:
393+
SimpleDistributionEstimator: An instance of the estimator.
394+
"""
395+
super().__init__()
396+
401397
def _compute_cumulative_distribution(
402398
self,
403399
target_treatment_arms: np.ndarray,
@@ -432,12 +428,12 @@ def _compute_cumulative_distribution(
432428
cumulative_distribution = np.zeros(locations.shape)
433429
for i, (outcome, arm) in enumerate(zip(locations, target_treatment_arms)):
434430
cumulative_distribution[i] = (
435-
find_le(d_outcome[arm], outcome) + 1
431+
np.searchsorted(d_outcome[arm], outcome, side="right")
436432
) / d_outcome[arm].shape[0]
437433
return cumulative_distribution, np.zeros((n_obs, n_loc))
438434

439435

440-
class AdjustedDistributionEstimator(DistributionFunctionMixin):
436+
class AdjustedDistributionEstimator(DistributionEstimatorBase):
441437
"""A class for estimating the adjusted distribution function and computing the distributional parameters based on the trained conditional estimator."""
442438

443439
def __init__(self, base_model, folds=3):
@@ -450,60 +446,16 @@ def __init__(self, base_model, folds=3):
450446
Returns:
451447
AdjustedDistributionEstimator: An instance of the estimator.
452448
"""
449+
if (not hasattr(base_model, "predict")) and (
450+
not hasattr(base_model, "predict_proba")
451+
):
452+
raise ValueError(
453+
"Base model should implement either predict_proba or predict"
454+
)
453455
self.base_model = base_model
454456
self.folds = folds
455457
super().__init__()
456458

457-
def fit(
458-
self, confoundings: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
459-
) -> "AdjustedDistributionEstimator":
460-
"""Train the AdjustedDistributionEstimator.
461-
462-
Args:
463-
confoundings (np.ndarray): Pre-treatment covariates.
464-
treatment_arms (np.ndarray): The index of the treatment arm.
465-
outcomes (np.ndarray): Scalar-valued observed outcome.
466-
467-
Returns:
468-
AdjustedDistributionEstimator: The fitted estimator.
469-
"""
470-
if confoundings.shape[0] != treatment_arms.shape[0]:
471-
raise ValueError(
472-
"The shape of confounding and treatment_arm should be same"
473-
)
474-
475-
if confoundings.shape[0] != outcomes.shape[0]:
476-
raise ValueError("The shape of confounding and outcome should be same")
477-
478-
self.confoundings = confoundings
479-
self.treatment_arms = treatment_arms
480-
self.outcomes = outcomes
481-
482-
return self
483-
484-
def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarray:
485-
"""Compute cumulative distribution values.
486-
487-
Args:
488-
treatment_arms (np.ndarray): The index of the treatment arm.
489-
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
490-
491-
Returns:
492-
np.ndarray: Estimated cumulative distribution values for the input.
493-
"""
494-
if self.outcomes is None:
495-
raise ValueError(
496-
"This estimator has not been trained yet. Please call fit first"
497-
)
498-
499-
return self._compute_cumulative_distribution(
500-
treatment_arms,
501-
locations,
502-
self.confoundings,
503-
self.treatment_arms,
504-
self.outcomes,
505-
)[0]
506-
507459
def _compute_cumulative_distribution(
508460
self,
509461
target_treatment_arms: np.ndarray,
@@ -548,13 +500,19 @@ def _compute_cumulative_distribution(
548500
continue
549501
model = deepcopy(self.base_model)
550502
model.fit(confounding_train, binominal_train)
551-
subset_prediction[subset_mask] = model.predict_proba(confounding_fit)[
552-
:, 1
553-
]
554-
superset_prediction[superset_mask, i] = model.predict_proba(
555-
confoundings[superset_mask]
556-
)[:, 1]
503+
subset_prediction[subset_mask] = self._compute_model_prediction(
504+
model, confounding_fit
505+
)
506+
superset_prediction[superset_mask, i] = self._compute_model_prediction(
507+
model, confoundings[superset_mask]
508+
)
557509
cumulative_distribution[i] = (
558510
cdf - subset_prediction.mean() + superset_prediction[:, i].mean()
559511
)
560512
return cumulative_distribution, superset_prediction
513+
514+
def _compute_model_prediction(self, model, confoundings: np.ndarray) -> np.ndarray:
515+
if hasattr(model, "predict_proba"):
516+
return model.predict_proba(confoundings)[:, 1]
517+
else:
518+
return model.predict(confoundings)

dte_adj/util.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy.stats import norm
3+
from typing import Tuple
34

45

56
def compute_confidence_intervals(
@@ -16,7 +17,7 @@ def compute_confidence_intervals(
1617
alpha: 0.05,
1718
variance_type="moment",
1819
n_bootstrap=500,
19-
):
20+
) -> Tuple[np.ndarray, np.ndarray]:
2021
"""Computes the confidence intervals of distribution parameters.
2122
2223
Args:
@@ -106,25 +107,3 @@ def compute_confidence_intervals(
106107
return vec_dte_lower_simple, vec_dte_upper_simple
107108
else:
108109
raise ValueError(f"Invalid variance type was speficied: {variance_type}")
109-
110-
111-
def find_le(array: np.ndarray, threshold):
112-
"""Find the rightmost value less than or equal to threshold in a sorted array
113-
114-
Args:
115-
array (np.ndarray): The sorted array to search in.
116-
threshold (float): The threshold value.
117-
118-
Returns:
119-
int: The index of the rightmost value less than or equal to the threshold, or -1 if no such value exists.
120-
"""
121-
low, high = 0, array.shape[0] - 1
122-
result = -1
123-
while low <= high:
124-
mid = (low + high) // 2
125-
if array[mid] <= threshold:
126-
result = mid
127-
low = mid + 1
128-
else:
129-
high = mid - 1
130-
return result

tests/test_adjusted_estimator.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,25 @@
55

66

77
class TestAdjustedEstimator(unittest.TestCase):
8-
def test_prediction_success(self):
9-
# TODO!
10-
return
8+
def setUp(self):
9+
base_model = MagicMock()
10+
base_model.predict_proba.side_effect = lambda x, y: x
11+
self.estimator = AdjustedDistributionEstimator(base_model, folds=1)
12+
self.confoundings = np.zeros((20, 5))
13+
self.treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
14+
self.outcomes = np.arange(20)
15+
self.estimator.fit(self.confoundings, self.treatment_arms, self.outcomes)
16+
17+
def test_init_fail_incorrect_base_model(self):
18+
# Act, Assert
19+
with self.assertRaises(ValueError) as cm:
20+
AdjustedDistributionEstimator("dummy")
21+
self.assertEqual(
22+
str(cm.exception),
23+
"Base model should implement either predict_proba or predict",
24+
)
1125

12-
def test_prediction_fail_before_fit(self):
26+
def test_predict_fail_before_fit(self):
1327
# Arrange
1428
D = np.zeros(20)
1529
D[:10] = 1
@@ -41,3 +55,32 @@ def test_fit_fail_invalid_input(self):
4155
str(cm.exception),
4256
"The shape of confounding and treatment_arm should be same",
4357
)
58+
59+
def test_compute_cumulative_distribution(self):
60+
# Arrange
61+
mock_model = self.estimator.base_model
62+
mock_model.predict_proba.side_effect = lambda x: np.ones((x.shape[0], 2)) * 0.5
63+
target_treatment_arms = np.zeros(10)
64+
locations = np.arange(10)
65+
66+
# Act
67+
cumulative_distribution, superset_prediction = (
68+
self.estimator._compute_cumulative_distribution(
69+
target_treatment_arms,
70+
locations,
71+
self.confoundings,
72+
self.treatment_arms,
73+
self.outcomes,
74+
)
75+
)
76+
77+
# Assert
78+
self.assertEqual(cumulative_distribution.shape, (10,))
79+
self.assertEqual(superset_prediction.shape, (20, 10))
80+
81+
for i in range(10):
82+
self.assertAlmostEqual(cumulative_distribution[i], (i + 1) / 10, places=2)
83+
84+
for i in range(20):
85+
for j in range(10):
86+
self.assertAlmostEqual(superset_prediction[i, j], 0.5, places=2)

0 commit comments

Comments
 (0)