|
3 | 3 | from scipy.stats import norm |
4 | 4 | from copy import deepcopy |
5 | 5 | from abc import ABC |
6 | | -from .util import compute_confidence_intervals |
| 6 | +from .util import compute_confidence_intervals, compute_ldte, compute_lpte |
7 | 7 |
|
8 | 8 | __all__ = [ |
9 | 9 | "SimpleDistributionEstimator", |
10 | 10 | "AdjustedDistributionEstimator", |
11 | 11 | "SimpleStratifiedDistributionEstimator", |
12 | 12 | "AdjustedStratifiedDistributionEstimator", |
| 13 | + "SimpleLocalDistributionEstimator", |
| 14 | + "AdjustedLocalDistributionEstimator", |
13 | 15 | ] |
14 | 16 |
|
15 | 17 |
|
@@ -835,3 +837,192 @@ def fit( |
835 | 837 | self.strata = np.zeros(len(self.covariates)) |
836 | 838 |
|
837 | 839 | return self |
| 840 | + |
| 841 | + |
| 842 | +class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator): |
| 843 | + """A class for computing local distribution treatment effects (LDTE) using simple empirical estimation.""" |
| 844 | + |
| 845 | + def __init__(self): |
| 846 | + """ |
| 847 | + Initializes the SimpleLocalDistributionEstimator. |
| 848 | +
|
| 849 | + Returns: |
| 850 | + SimpleLocalDistributionEstimator: An instance of the estimator. |
| 851 | + """ |
| 852 | + super().__init__() |
| 853 | + |
| 854 | + def fit( |
| 855 | + self, |
| 856 | + covariates: np.ndarray, |
| 857 | + treatment_arms: np.ndarray, |
| 858 | + treatment_indicator: np.ndarray, |
| 859 | + outcomes: np.ndarray, |
| 860 | + strata: np.ndarray, |
| 861 | + ) -> "SimpleLocalDistributionEstimator": |
| 862 | + """ |
| 863 | + Train the SimpleLocalDistributionEstimator. |
| 864 | +
|
| 865 | + Args: |
| 866 | + covariates (np.ndarray): Pre-treatment covariates. |
| 867 | + treatment_arms (np.ndarray): Treatment assignment variable (Z). |
| 868 | + treatment_indicator (np.ndarray): Treatment indicator variable (D). |
| 869 | + outcomes (np.ndarray): Scalar-valued observed outcome. |
| 870 | + strata (np.ndarray): Stratum indicators. |
| 871 | +
|
| 872 | + Returns: |
| 873 | + SimpleLocalDistributionEstimator: The fitted estimator. |
| 874 | + """ |
| 875 | + super().fit(covariates, treatment_arms, outcomes, strata) |
| 876 | + self.treatment_indicator = treatment_indicator |
| 877 | + |
| 878 | + return self |
| 879 | + |
| 880 | + def predict_ldte( |
| 881 | + self, |
| 882 | + target_treatment_arm: int, |
| 883 | + control_treatment_arm: int, |
| 884 | + locations: np.ndarray, |
| 885 | + alpha: float = 0.05, |
| 886 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 887 | + """ |
| 888 | + Compute Local Distribution Treatment Effects (LDTE). |
| 889 | +
|
| 890 | + Args: |
| 891 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 892 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 893 | + locations (np.ndarray): Scalar values to be used for computing the cumulative distribution. |
| 894 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 895 | +
|
| 896 | + Returns: |
| 897 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 898 | + - Expected LDTEs |
| 899 | + - Lower bounds |
| 900 | + - Upper bounds |
| 901 | + """ |
| 902 | + return compute_ldte( |
| 903 | + self, target_treatment_arm, control_treatment_arm, locations, alpha |
| 904 | + ) |
| 905 | + |
| 906 | + def predict_lpte( |
| 907 | + self, |
| 908 | + target_treatment_arm: int, |
| 909 | + control_treatment_arm: int, |
| 910 | + locations: np.ndarray, |
| 911 | + alpha: float = 0.05, |
| 912 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 913 | + """ |
| 914 | + Compute Local Probability Treatment Effects (LPTE). |
| 915 | +
|
| 916 | + Args: |
| 917 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 918 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 919 | + locations (np.ndarray): Scalar values to be used for computing the interval probabilities. |
| 920 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 921 | +
|
| 922 | + Returns: |
| 923 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 924 | + - Expected LPTEs |
| 925 | + - Lower bounds |
| 926 | + - Upper bounds |
| 927 | + """ |
| 928 | + return compute_lpte( |
| 929 | + self, target_treatment_arm, control_treatment_arm, locations, alpha |
| 930 | + ) |
| 931 | + |
| 932 | + |
| 933 | +class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator): |
| 934 | + """A class for computing local distribution treatment effects (LDTE) using adjusted estimation with ML models.""" |
| 935 | + |
| 936 | + def __init__(self, base_model: Any, folds=3, is_multi_task=False): |
| 937 | + """ |
| 938 | + Initializes the AdjustedLocalDistributionEstimator. |
| 939 | +
|
| 940 | + Args: |
| 941 | + base_model (scikit-learn estimator): The base model implementing used for conditional distribution function estimators. |
| 942 | + folds (int): The number of folds for cross-fitting. |
| 943 | + is_multi_task(bool): Whether to use multi-task learning. |
| 944 | +
|
| 945 | + Returns: |
| 946 | + AdjustedLocalDistributionEstimator: An instance of the estimator. |
| 947 | + """ |
| 948 | + super().__init__(base_model, folds, is_multi_task) |
| 949 | + |
| 950 | + def fit( |
| 951 | + self, |
| 952 | + covariates: np.ndarray, |
| 953 | + treatment_arms: np.ndarray, |
| 954 | + treatment_indicator: np.ndarray, |
| 955 | + outcomes: np.ndarray, |
| 956 | + strata: np.ndarray, |
| 957 | + ) -> "AdjustedLocalDistributionEstimator": |
| 958 | + """ |
| 959 | + Train the AdjustedLocalDistributionEstimator. |
| 960 | +
|
| 961 | + Args: |
| 962 | + covariates (np.ndarray): Pre-treatment covariates. |
| 963 | + treatment_arms (np.ndarray): Treatment assignment variable (Z). |
| 964 | + treatment_indicator (np.ndarray): Treatment indicator variable (D). |
| 965 | + outcomes (np.ndarray): Scalar-valued observed outcome. |
| 966 | + strata (np.ndarray): Stratum indicators. |
| 967 | +
|
| 968 | + Returns: |
| 969 | + AdjustedLocalDistributionEstimator: The fitted estimator. |
| 970 | + """ |
| 971 | + super().fit(covariates, treatment_arms, outcomes, strata) |
| 972 | + self.treatment_indicator = treatment_indicator |
| 973 | + |
| 974 | + return self |
| 975 | + |
| 976 | + def predict_ldte( |
| 977 | + self, |
| 978 | + target_treatment_arm: int, |
| 979 | + control_treatment_arm: int, |
| 980 | + locations: np.ndarray, |
| 981 | + alpha: float = 0.05, |
| 982 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 983 | + """ |
| 984 | + Compute Local Distribution Treatment Effects (LDTE). |
| 985 | + Currently, this API only supports analytical confidence interval. |
| 986 | +
|
| 987 | + Args: |
| 988 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 989 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 990 | + locations (np.ndarray): Scalar values to be used for computing the cumulative distribution. |
| 991 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 992 | +
|
| 993 | + Returns: |
| 994 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 995 | + - Expected LDTEs |
| 996 | + - Lower bounds |
| 997 | + - Upper bounds |
| 998 | + """ |
| 999 | + return compute_ldte( |
| 1000 | + self, target_treatment_arm, control_treatment_arm, locations, alpha |
| 1001 | + ) |
| 1002 | + |
| 1003 | + def predict_lpte( |
| 1004 | + self, |
| 1005 | + target_treatment_arm: int, |
| 1006 | + control_treatment_arm: int, |
| 1007 | + locations: np.ndarray, |
| 1008 | + alpha: float = 0.05, |
| 1009 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 1010 | + """ |
| 1011 | + Compute Local Probability Treatment Effects (LPTE). |
| 1012 | + Currently, this API only supports analytical confidence interval. |
| 1013 | +
|
| 1014 | + Args: |
| 1015 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 1016 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 1017 | + locations (np.ndarray): Scalar values to be used for computing the interval probabilities. |
| 1018 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 1019 | +
|
| 1020 | + Returns: |
| 1021 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 1022 | + - Expected LPTEs |
| 1023 | + - Lower bounds |
| 1024 | + - Upper bounds |
| 1025 | + """ |
| 1026 | + return compute_lpte( |
| 1027 | + self, target_treatment_arm, control_treatment_arm, locations, alpha |
| 1028 | + ) |
0 commit comments