33from scipy .stats import norm
44import math
55from copy import deepcopy
6+ from .util import compute_confidence_intervals , find_le
67
78
89class DistributionFunctionMixin (object ):
@@ -33,15 +34,15 @@ def predict_dte(
3334 target_treatment_arm (int): The index of the treatment arm of the treatment group.
3435 control_treatment_arm (int): The index of the treatment arm of the control group.
3536 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
36- alpha (float, optional): Significance level of the confidence band . Defaults to 0.05.
37+ alpha (float, optional): Significance level of the confidence bound . Defaults to 0.05.
3738 variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
3839 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
3940
4041 Returns:
4142 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
4243 - Expected DTEs
43- - Upper bands
44- - Lower bands
44+ - Upper bounds
45+ - Lower bounds
4546 """
4647 return self ._compute_dtes (
4748 target_treatment_arm ,
@@ -68,14 +69,14 @@ def predict_pte(
6869 control_treatment_arm (int): The index of the treatment arm of the control group.
6970 locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
7071 width (float): The width of each outcome interval.
71- alpha (float, optional): Significance level of the confidence band . Defaults to 0.05.
72+ alpha (float, optional): Significance level of the confidence bound . Defaults to 0.05.
7273 variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
7374
7475 Returns:
7576 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
7677 - Expected PTEs
77- - Upper bands
78- - Lower bands
78+ - Upper bounds
79+ - Lower bounds
7980 """
8081 return self ._compute_ptes (
8182 target_treatment_arm ,
@@ -102,14 +103,14 @@ def predict_qte(
102103 target_treatment_arm (int): The index of the treatment arm of the treatment group.
103104 control_treatment_arm (int): The index of the treatment arm of the control group.
104105 quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].
105- alpha (float, optional): Significance level of the confidence band . Defaults to 0.05.
106+ alpha (float, optional): Significance level of the confidence bound . Defaults to 0.05.
106107 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
107108
108109 Returns:
109110 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
110111 - Expected QTEs
111- - Upper bands
112- - Lower bands
112+ - Upper bounds
113+ - Lower bounds
113114 """
114115 qte = self ._compute_qtes (
115116 target_treatment_arm ,
@@ -170,7 +171,7 @@ def _compute_dtes(
170171
171172 mat_indicator = (self .outcome [:, np .newaxis ] <= locations ).astype (int )
172173
173- lower_band , upper_band = compute_confidence_intervals (
174+ lower_bound , upper_bound = compute_confidence_intervals (
174175 vec_y = self .outcome ,
175176 vec_d = self .treatment_arm ,
176177 vec_loc = locations ,
@@ -188,8 +189,8 @@ def _compute_dtes(
188189
189190 return (
190191 dte ,
191- lower_band ,
192- upper_band ,
192+ lower_bound ,
193+ upper_bound ,
193194 )
194195
195196 def _compute_ptes (
@@ -248,7 +249,7 @@ def _compute_ptes(
248249 int
249250 )
250251
251- lower_band , upper_band = compute_confidence_intervals (
252+ lower_bound , upper_bound = compute_confidence_intervals (
252253 vec_y = self .outcome ,
253254 vec_d = self .treatment_arm ,
254255 vec_loc = locations ,
@@ -266,8 +267,8 @@ def _compute_ptes(
266267
267268 return (
268269 pte ,
269- lower_band ,
270- upper_band ,
270+ lower_bound ,
271+ upper_bound ,
271272 )
272273
273274 def _compute_qtes (
@@ -326,28 +327,6 @@ def _compute_cumulative_distribution(
326327 raise NotImplementedError ()
327328
328329
329- def find_le (array : np .ndarray , threshold ):
330- """Find the rightmost value less than or equal to threshold in a sorted array
331-
332- Args:
333- array (np.ndarray): The sorted array to search in.
334- threshold (float): The threshold value.
335-
336- Returns:
337- int: The index where the value first exceeds the threshold.
338- """
339- low , high = 0 , array .shape [0 ] - 1
340- result = - 1
341- while low <= high :
342- mid = (low + high ) // 2
343- if array [mid ] <= threshold :
344- result = mid
345- low = mid + 1
346- else :
347- high = mid - 1
348- return result
349-
350-
351330class SimpleDistributionEstimator (DistributionFunctionMixin ):
352331 """A class for computing the empirical distribution function and the distributional parameters
353332 based on the distribution function.
@@ -564,108 +543,3 @@ def _compute_cumulative_distribution(
564543 )
565544 return cumulative_distribution , superset_prediction
566545
567-
568- def compute_confidence_intervals (
569- vec_y : np .ndarray ,
570- vec_d : np .ndarray ,
571- vec_loc : np .ndarray ,
572- mat_y_u : np .ndarray ,
573- vec_prediction_target : np .ndarray ,
574- vec_prediction_control : np .ndarray ,
575- mat_entire_predictions_target : np .ndarray ,
576- mat_entire_predictions_control : np .ndarray ,
577- ind_target : int ,
578- ind_control : int ,
579- alpha : 0.05 ,
580- variance_type = "moment" ,
581- n_bootstrap = 500 ,
582- ):
583- """Computes the confidence intervals of distribution parameters.
584-
585- Args:
586- vec_y (np.ndarray): Outcome variable vector.
587- vec_d (np.ndarray): Treatment indicator vector.
588- vec_loc (np.ndarray): Locations where the distribution parameters are estimated.
589- mat_y_u (np.ndarray): Indicator function for 1{Y⩽y}. Shape is n_obs * n_loc.
590- vec_prediction_target (np.ndarray): Estimated values from the conditional model for the treatment group.
591- vec_prediction_control (np.ndarray): Estimated values from the conditional model for the control group.
592- mat_entire_predictions_target (np.ndarray): Prediction of the conditional distribution estimator for target group.
593- mat_entire_predictions_control (np.ndarray): Prediction of the conditional distribution estimator for control group.
594- ind_target (int): Index of the target treatment indicator.
595- ind_control (int): Index of the control treatment indicator.
596- alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
597- variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
598- n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
599-
600- Returns:
601- Tuple[np.ndarray, np.ndarray]: A tuple containing:
602- - np.ndarray: lower band.
603- - np.ndarray: upper band.
604- """
605- num_obs = vec_y .shape [0 ]
606- n_loc = vec_loc .shape [0 ]
607- mat_d = np .tile (vec_d , (n_loc , 1 )).T
608- vec_dte = vec_prediction_target - vec_prediction_control
609- mat_dte = np .tile (vec_dte , (num_obs , 1 ))
610-
611- num_target = (vec_d == ind_target ).sum ()
612- num_control = (vec_d == ind_control ).sum ()
613- influence_function = (
614- num_obs
615- / num_target
616- * (mat_d == ind_target )
617- * (mat_y_u - mat_entire_predictions_target )
618- + mat_entire_predictions_target
619- - num_obs
620- / num_control
621- * (mat_d == ind_control )
622- * (mat_y_u - mat_entire_predictions_control )
623- - mat_entire_predictions_control
624- - mat_dte
625- )
626-
627- omega = (influence_function ** 2 ).mean (axis = 0 )
628-
629- if variance_type == "moment" :
630- vec_dte_lower_moment = vec_dte + norm .ppf (alpha / 2 ) * np .sqrt (omega / num_obs )
631- vec_dte_upper_moment = vec_dte + norm .ppf (1 - alpha / 2 ) * np .sqrt (
632- omega / num_obs
633- )
634- return vec_dte_lower_moment , vec_dte_upper_moment
635- elif variance_type == "uniform" :
636- tstats = np .zeros ((n_bootstrap , len (vec_loc )))
637- boot_draw = np .zeros ((n_bootstrap , len (vec_loc )))
638-
639- for b in range (n_bootstrap ):
640- eta1 = np .random .normal (0 , 1 , num_obs )
641- eta2 = np .random .normal (0 , 1 , num_obs )
642- xi = eta1 / np .sqrt (2 ) + (eta2 ** 2 - 1 ) / 2
643-
644- boot_draw [b , :] = (
645- 1 / num_obs * np .sum (xi [:, np .newaxis ] * influence_function , axis = 0 )
646- )
647-
648- tstats = np .abs (boot_draw )[:, :- 1 ] / np .sqrt (omega [:- 1 ] / num_obs )
649- max_tstats = np .max (tstats , axis = 1 )
650- quantile_max_tstats = np .quantile (max_tstats , 1 - alpha )
651-
652- vec_dte_lower_boot = vec_dte - quantile_max_tstats * np .sqrt (omega / num_obs )
653- vec_dte_upper_boot = vec_dte + quantile_max_tstats * np .sqrt (omega / num_obs )
654- return vec_dte_lower_boot , vec_dte_upper_boot
655- elif variance_type == "simple" :
656- w_target = num_obs / num_target
657- w_control = num_obs / num_control
658- vec_dte_var = w_target * (
659- vec_prediction_target * (1 - vec_prediction_target )
660- ) + w_control * vec_prediction_control * (1 - vec_prediction_control )
661-
662- vec_dte_lower_simple = vec_dte + norm .ppf (alpha / 2 ) / np .sqrt (
663- num_obs
664- ) * np .sqrt (vec_dte_var )
665- vec_dte_upper_simple = vec_dte + norm .ppf (1 - alpha / 2 ) / np .sqrt (
666- num_obs
667- ) * np .sqrt (vec_dte_var )
668-
669- return vec_dte_lower_simple , vec_dte_upper_simple
670- else :
671- raise RuntimeError (f"Invalid variance type was speficied: { variance_type } " )
0 commit comments