22from typing import Tuple , Optional
33from scipy .stats import norm
44from abc import ABC
5+ from tqdm .auto import tqdm
56import dte_adj
67
78
@@ -27,6 +28,7 @@ def predict_dte(
2728 alpha : float = 0.05 ,
2829 variance_type = "moment" ,
2930 n_bootstrap = 500 ,
31+ display_progress : bool = True ,
3032 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
3133 """
3234 Compute Distribution Treatment Effects (DTE) based on the estimator for the distribution function.
@@ -43,6 +45,7 @@ def predict_dte(
4345 variance_type (str, optional): Variance type to be used to compute confidence intervals.
4446 Available values are "moment", "simple", and "uniform". Defaults to "moment".
4547 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
48+ display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
4649
4750 Returns:
4851 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -84,6 +87,7 @@ def predict_dte(
8487 alpha ,
8588 variance_type ,
8689 n_bootstrap ,
90+ display_progress ,
8791 )
8892
8993 def predict_pte (
@@ -94,6 +98,7 @@ def predict_pte(
9498 alpha : float = 0.05 ,
9599 variance_type = "moment" ,
96100 n_bootstrap = 500 ,
101+ display_progress : bool = True ,
97102 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
98103 """
99104 Compute Probability Treatment Effects (PTE) based on the estimator for the distribution function.
@@ -111,6 +116,7 @@ def predict_pte(
111116 variance_type (str, optional): Variance type to be used to compute confidence intervals.
112117 Available values are "moment", "simple", and "uniform". Defaults to "moment".
113118 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
119+ display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
114120
115121 Returns:
116122 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -155,6 +161,7 @@ def predict_pte(
155161 alpha ,
156162 variance_type ,
157163 n_bootstrap ,
164+ display_progress ,
158165 )
159166
160167 def predict_qte (
@@ -164,6 +171,7 @@ def predict_qte(
164171 quantiles : Optional [np .ndarray ] = None ,
165172 alpha : float = 0.05 ,
166173 n_bootstrap = 500 ,
174+ display_progress : bool = True ,
167175 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
168176 """
169177 Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.
@@ -178,6 +186,7 @@ def predict_qte(
178186 quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
179187 alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
180188 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
189+ display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
181190
182191 Returns:
183192 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -226,7 +235,10 @@ def predict_qte(
226235 indexes = np .arange (n_obs )
227236
228237 qtes = np .zeros ((n_bootstrap , qte .shape [0 ]))
229- for b in range (n_bootstrap ):
238+ bootstrap_iter = range (n_bootstrap )
239+ if display_progress :
240+ bootstrap_iter = tqdm (bootstrap_iter , desc = "Bootstrap QTE" )
241+ for b in bootstrap_iter :
230242 bootstrap_indexes = np .random .choice (indexes , size = n_obs , replace = True )
231243
232244 qtes [b ] = self ._compute_qtes (
@@ -254,6 +266,7 @@ def _compute_dtes(
254266 alpha : float ,
255267 variance_type : str ,
256268 n_bootstrap : int ,
269+ display_progress : bool = False ,
257270 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
258271 """Compute expected DTEs."""
259272 treatment_cdf , treatment_cdf_mat , _ = self ._compute_cumulative_distribution (
@@ -262,13 +275,15 @@ def _compute_dtes(
262275 self .covariates ,
263276 self .treatment_arms ,
264277 self .outcomes ,
278+ display_progress = display_progress ,
265279 )
266280 control_cdf , control_cdf_mat , _ = self ._compute_cumulative_distribution (
267281 control_treatment_arm ,
268282 locations ,
269283 self .covariates ,
270284 self .treatment_arms ,
271285 self .outcomes ,
286+ display_progress = display_progress ,
272287 )
273288
274289 dte = treatment_cdf - control_cdf
@@ -305,6 +320,7 @@ def _compute_ptes(
305320 alpha : float ,
306321 variance_type : str ,
307322 n_bootstrap : int ,
323+ display_progress : bool = False ,
308324 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
309325 """Compute expected PTEs."""
310326 treatment_pdf , treatment_pdf_mat , _ = self ._compute_interval_probability (
@@ -313,13 +329,15 @@ def _compute_ptes(
313329 self .covariates ,
314330 self .treatment_arms ,
315331 self .outcomes ,
332+ display_progress = display_progress ,
316333 )
317334 control_pdf , control_pdf_mat , _ = self ._compute_interval_probability (
318335 control_treatment_arm ,
319336 locations ,
320337 self .covariates ,
321338 self .treatment_arms ,
322339 self .outcomes ,
340+ display_progress = display_progress ,
323341 )
324342
325343 pte = treatment_pdf - control_pdf
@@ -398,13 +416,16 @@ def find_quantile(quantile, arm):
398416
399417 return result
400418
401- def predict (self , treatment_arm : int , locations : np .ndarray ) -> np .ndarray :
419+ def predict (
420+ self , treatment_arm : int , locations : np .ndarray , display_progress : bool = True
421+ ) -> np .ndarray :
402422 """
403423 Compute cumulative distribution values.
404424
405425 Args:
406426 treatment_arm (int): The index of the treatment arm.
407427 outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
428+ display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
408429
409430 Returns:
410431 np.ndarray: Estimated cumulative distribution values for the input.
@@ -425,6 +446,7 @@ def predict(self, treatment_arm: int, locations: np.ndarray) -> np.ndarray:
425446 self .covariates ,
426447 self .treatment_arms ,
427448 self .outcomes ,
449+ display_progress = display_progress ,
428450 )[0 ]
429451
430452 def _compute_cumulative_distribution (
@@ -434,6 +456,7 @@ def _compute_cumulative_distribution(
434456 covariates : np .ndarray ,
435457 treatment_arms : np .ndarray ,
436458 outcomes : np .array ,
459+ display_progress : bool = False ,
437460 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
438461 """
439462 Compute the cumulative distribution values.
@@ -444,6 +467,7 @@ def _compute_cumulative_distribution(
444467 covariates: (np.ndarray): An array of covariates variables in the observed data.
445468 treatment_arms (np.ndarray): An array of treatment arms in the observed data.
446469 outcomes (np.ndarray): An array of outcomes in the observed data.
470+ display_progress (bool): Whether to display a progress bar.
447471
448472 Returns:
449473 Tuple[np.ndarray, np.ndarray, np.ndarray]: Estimated cumulative distribution values, prediction for each observation, and superset prediction for each observation.
0 commit comments