Skip to content

Commit 767c96a

Browse files
authored
Add LDTE/LPTE estimator (#42)
* add ldte estimator * add comment * update dependency
1 parent 6ac8f24 commit 767c96a

7 files changed

Lines changed: 983 additions & 179 deletions

File tree

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ scipy = ">=1.13.1"
1212
build = "~=1.2.1"
1313
ruff = "~=0.4.9"
1414
sphinx = "~=7.3.7"
15+
scikit-learn = "~=1.5.0"
1516

1617
[requires]
1718
python_version = "3"

Pipfile.lock

Lines changed: 340 additions & 176 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
project = "dte_adj"
1414
copyright = "2024, CyberAgent, Inc."
1515
author = "CyberAgent, Inc"
16-
release = "0.1.5"
16+
release = "0.1.6"
1717

1818
# -- General configuration ---------------------------------------------------
1919
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

dte_adj/__init__.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from scipy.stats import norm
44
from copy import deepcopy
55
from abc import ABC
6-
from .util import compute_confidence_intervals
6+
from .util import compute_confidence_intervals, compute_ldte, compute_lpte
77

88
__all__ = [
99
"SimpleDistributionEstimator",
1010
"AdjustedDistributionEstimator",
1111
"SimpleStratifiedDistributionEstimator",
1212
"AdjustedStratifiedDistributionEstimator",
13+
"SimpleLocalDistributionEstimator",
14+
"AdjustedLocalDistributionEstimator",
1315
]
1416

1517

@@ -835,3 +837,192 @@ def fit(
835837
self.strata = np.zeros(len(self.covariates))
836838

837839
return self
840+
841+
842+
class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator):
843+
"""A class for computing local distribution treatment effects (LDTE) using simple empirical estimation."""
844+
845+
def __init__(self):
846+
"""
847+
Initializes the SimpleLocalDistributionEstimator.
848+
849+
Returns:
850+
SimpleLocalDistributionEstimator: An instance of the estimator.
851+
"""
852+
super().__init__()
853+
854+
def fit(
855+
self,
856+
covariates: np.ndarray,
857+
treatment_arms: np.ndarray,
858+
treatment_indicator: np.ndarray,
859+
outcomes: np.ndarray,
860+
strata: np.ndarray,
861+
) -> "SimpleLocalDistributionEstimator":
862+
"""
863+
Train the SimpleLocalDistributionEstimator.
864+
865+
Args:
866+
covariates (np.ndarray): Pre-treatment covariates.
867+
treatment_arms (np.ndarray): Treatment assignment variable (Z).
868+
treatment_indicator (np.ndarray): Treatment indicator variable (D).
869+
outcomes (np.ndarray): Scalar-valued observed outcome.
870+
strata (np.ndarray): Stratum indicators.
871+
872+
Returns:
873+
SimpleLocalDistributionEstimator: The fitted estimator.
874+
"""
875+
super().fit(covariates, treatment_arms, outcomes, strata)
876+
self.treatment_indicator = treatment_indicator
877+
878+
return self
879+
880+
def predict_ldte(
881+
self,
882+
target_treatment_arm: int,
883+
control_treatment_arm: int,
884+
locations: np.ndarray,
885+
alpha: float = 0.05,
886+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
887+
"""
888+
Compute Local Distribution Treatment Effects (LDTE).
889+
890+
Args:
891+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
892+
control_treatment_arm (int): The index of the treatment arm of the control group.
893+
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
894+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
895+
896+
Returns:
897+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
898+
- Expected LDTEs
899+
- Lower bounds
900+
- Upper bounds
901+
"""
902+
return compute_ldte(
903+
self, target_treatment_arm, control_treatment_arm, locations, alpha
904+
)
905+
906+
def predict_lpte(
907+
self,
908+
target_treatment_arm: int,
909+
control_treatment_arm: int,
910+
locations: np.ndarray,
911+
alpha: float = 0.05,
912+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
913+
"""
914+
Compute Local Probability Treatment Effects (LPTE).
915+
916+
Args:
917+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
918+
control_treatment_arm (int): The index of the treatment arm of the control group.
919+
locations (np.ndarray): Scalar values to be used for computing the interval probabilities.
920+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
921+
922+
Returns:
923+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
924+
- Expected LPTEs
925+
- Lower bounds
926+
- Upper bounds
927+
"""
928+
return compute_lpte(
929+
self, target_treatment_arm, control_treatment_arm, locations, alpha
930+
)
931+
932+
933+
class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator):
934+
"""A class for computing local distribution treatment effects (LDTE) using adjusted estimation with ML models."""
935+
936+
def __init__(self, base_model: Any, folds=3, is_multi_task=False):
937+
"""
938+
Initializes the AdjustedLocalDistributionEstimator.
939+
940+
Args:
941+
base_model (scikit-learn estimator): The base model implementing used for conditional distribution function estimators.
942+
folds (int): The number of folds for cross-fitting.
943+
is_multi_task(bool): Whether to use multi-task learning.
944+
945+
Returns:
946+
AdjustedLocalDistributionEstimator: An instance of the estimator.
947+
"""
948+
super().__init__(base_model, folds, is_multi_task)
949+
950+
def fit(
951+
self,
952+
covariates: np.ndarray,
953+
treatment_arms: np.ndarray,
954+
treatment_indicator: np.ndarray,
955+
outcomes: np.ndarray,
956+
strata: np.ndarray,
957+
) -> "AdjustedLocalDistributionEstimator":
958+
"""
959+
Train the AdjustedLocalDistributionEstimator.
960+
961+
Args:
962+
covariates (np.ndarray): Pre-treatment covariates.
963+
treatment_arms (np.ndarray): Treatment assignment variable (Z).
964+
treatment_indicator (np.ndarray): Treatment indicator variable (D).
965+
outcomes (np.ndarray): Scalar-valued observed outcome.
966+
strata (np.ndarray): Stratum indicators.
967+
968+
Returns:
969+
AdjustedLocalDistributionEstimator: The fitted estimator.
970+
"""
971+
super().fit(covariates, treatment_arms, outcomes, strata)
972+
self.treatment_indicator = treatment_indicator
973+
974+
return self
975+
976+
def predict_ldte(
977+
self,
978+
target_treatment_arm: int,
979+
control_treatment_arm: int,
980+
locations: np.ndarray,
981+
alpha: float = 0.05,
982+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
983+
"""
984+
Compute Local Distribution Treatment Effects (LDTE).
985+
Currently, this API only supports analytical confidence interval.
986+
987+
Args:
988+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
989+
control_treatment_arm (int): The index of the treatment arm of the control group.
990+
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
991+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
992+
993+
Returns:
994+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
995+
- Expected LDTEs
996+
- Lower bounds
997+
- Upper bounds
998+
"""
999+
return compute_ldte(
1000+
self, target_treatment_arm, control_treatment_arm, locations, alpha
1001+
)
1002+
1003+
def predict_lpte(
1004+
self,
1005+
target_treatment_arm: int,
1006+
control_treatment_arm: int,
1007+
locations: np.ndarray,
1008+
alpha: float = 0.05,
1009+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
1010+
"""
1011+
Compute Local Probability Treatment Effects (LPTE).
1012+
Currently, this API only supports analytical confidence interval.
1013+
1014+
Args:
1015+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
1016+
control_treatment_arm (int): The index of the treatment arm of the control group.
1017+
locations (np.ndarray): Scalar values to be used for computing the interval probabilities.
1018+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
1019+
1020+
Returns:
1021+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
1022+
- Expected LPTEs
1023+
- Lower bounds
1024+
- Upper bounds
1025+
"""
1026+
return compute_lpte(
1027+
self, target_treatment_arm, control_treatment_arm, locations, alpha
1028+
)

0 commit comments

Comments
 (0)