@@ -24,7 +24,7 @@ def predict_dte(
2424 control_treatment_arm : int ,
2525 outcomes : np .ndarray ,
2626 alpha : float = 0.05 ,
27- variance_type = "pointwise " ,
27+ variance_type = "moment " ,
2828 n_bootstrap = 500 ,
2929 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
3030 """Compute DTE based on the estimator for the distribution function.
@@ -34,7 +34,7 @@ def predict_dte(
3434 control_treatment_arm (int): The index of the treatment arm of the control group.
3535 outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
3636 alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
37- variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are pointwise , analytic, and uniform.
37+ variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment , analytic, and uniform.
3838 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
3939
4040 Returns:
@@ -196,12 +196,18 @@ def _compute_expected_qtes(
196196 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
197197 """Compute expected QTEs."""
198198 result = np .zeros (quantiles .shape )
199+ treatment_cumulative = self .predict (
200+ np .full (self .outcome .shape , target_treatment_arm ), self .outcome
201+ )
202+ control_cumulative = self .predict (
203+ np .full (self .outcome .shape , control_treatment_arm ), self .outcome
204+ )
199205 for i , q in enumerate (quantiles ):
200- treatment_quantile = self . outcome [ target_treatment_arm ] [
201- math .floor (self . outcome [ i ] .shape [0 ] * q )
206+ treatment_quantile = treatment_cumulative [
207+ math .floor (treatment_cumulative .shape [0 ] * q )
202208 ]
203- control_quantile = self . outcome [ control_treatment_arm ] [
204- math .floor (self . outcome [ i ] .shape [0 ] * q )
209+ control_quantile = control_cumulative [
210+ math .floor (control_cumulative .shape [0 ] * q )
205211 ]
206212 result [i ] = treatment_quantile - control_quantile
207213
@@ -481,7 +487,7 @@ def compute_dte_confidence_intervals(
481487 ind_target : int ,
482488 ind_control : int ,
483489 alpha : 0.05 ,
484- variance_type = "pointwise " ,
490+ variance_type = "moment " ,
485491 n_bootstrap = 500 ,
486492):
487493 """Computes the confidence intervals of DTE.
@@ -497,7 +503,7 @@ def compute_dte_confidence_intervals(
497503 ind_target (int): Index of the target treatment indicator.
498504 ind_control (int): Index of the control treatment indicator.
499505 alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
500- variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are pointwise , analytic, and uniform.
506+ variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment , analytic, and uniform.
501507 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
502508
503509 Returns:
@@ -530,7 +536,7 @@ def compute_dte_confidence_intervals(
530536
531537 omega = (influence_function ** 2 ).mean (axis = 0 )
532538
533- if variance_type == "pointwise " :
539+ if variance_type == "moment " :
534540 vec_dte_lower_moment = vec_dte + norm .ppf (alpha / 2 ) * np .sqrt (omega / num_obs )
535541 vec_dte_upper_moment = vec_dte + norm .ppf (1 - alpha / 2 ) * np .sqrt (
536542 omega / num_obs
0 commit comments