Skip to content

Commit a6c7c14

Browse files
authored
feat: add progress bar to predict methods (#63) (#94)
* feat: add progress bar to predict methods (#63) Add tqdm-based progress bars to long-running predict methods so users can track ETA. A `verbose` parameter (default True) controls display. Affected loops: - _compute_cumulative_distribution (locations / folds iteration) - _compute_interval_probability (interval iteration) - predict_qte bootstrap loop * Rename verbose to display_progress
1 parent 5111a47 commit a6c7c14

7 files changed

Lines changed: 1033 additions & 956 deletions

File tree

dte_adj/base.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Tuple, Optional
33
from scipy.stats import norm
44
from abc import ABC
5+
from tqdm.auto import tqdm
56
import dte_adj
67

78

@@ -27,6 +28,7 @@ def predict_dte(
2728
alpha: float = 0.05,
2829
variance_type="moment",
2930
n_bootstrap=500,
31+
display_progress: bool = True,
3032
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
3133
"""
3234
Compute Distribution Treatment Effects (DTE) based on the estimator for the distribution function.
@@ -43,6 +45,7 @@ def predict_dte(
4345
variance_type (str, optional): Variance type to be used to compute confidence intervals.
4446
Available values are "moment", "simple", and "uniform". Defaults to "moment".
4547
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
48+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
4649
4750
Returns:
4851
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -84,6 +87,7 @@ def predict_dte(
8487
alpha,
8588
variance_type,
8689
n_bootstrap,
90+
display_progress,
8791
)
8892

8993
def predict_pte(
@@ -94,6 +98,7 @@ def predict_pte(
9498
alpha: float = 0.05,
9599
variance_type="moment",
96100
n_bootstrap=500,
101+
display_progress: bool = True,
97102
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
98103
"""
99104
Compute Probability Treatment Effects (PTE) based on the estimator for the distribution function.
@@ -111,6 +116,7 @@ def predict_pte(
111116
variance_type (str, optional): Variance type to be used to compute confidence intervals.
112117
Available values are "moment", "simple", and "uniform". Defaults to "moment".
113118
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
119+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
114120
115121
Returns:
116122
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -155,6 +161,7 @@ def predict_pte(
155161
alpha,
156162
variance_type,
157163
n_bootstrap,
164+
display_progress,
158165
)
159166

160167
def predict_qte(
@@ -164,6 +171,7 @@ def predict_qte(
164171
quantiles: Optional[np.ndarray] = None,
165172
alpha: float = 0.05,
166173
n_bootstrap=500,
174+
display_progress: bool = True,
167175
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
168176
"""
169177
Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.
@@ -178,6 +186,7 @@ def predict_qte(
178186
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
179187
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
180188
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
189+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
181190
182191
Returns:
183192
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -226,7 +235,10 @@ def predict_qte(
226235
indexes = np.arange(n_obs)
227236

228237
qtes = np.zeros((n_bootstrap, qte.shape[0]))
229-
for b in range(n_bootstrap):
238+
bootstrap_iter = range(n_bootstrap)
239+
if display_progress:
240+
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
241+
for b in bootstrap_iter:
230242
bootstrap_indexes = np.random.choice(indexes, size=n_obs, replace=True)
231243

232244
qtes[b] = self._compute_qtes(
@@ -254,6 +266,7 @@ def _compute_dtes(
254266
alpha: float,
255267
variance_type: str,
256268
n_bootstrap: int,
269+
display_progress: bool = False,
257270
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
258271
"""Compute expected DTEs."""
259272
treatment_cdf, treatment_cdf_mat, _ = self._compute_cumulative_distribution(
@@ -262,13 +275,15 @@ def _compute_dtes(
262275
self.covariates,
263276
self.treatment_arms,
264277
self.outcomes,
278+
display_progress=display_progress,
265279
)
266280
control_cdf, control_cdf_mat, _ = self._compute_cumulative_distribution(
267281
control_treatment_arm,
268282
locations,
269283
self.covariates,
270284
self.treatment_arms,
271285
self.outcomes,
286+
display_progress=display_progress,
272287
)
273288

274289
dte = treatment_cdf - control_cdf
@@ -305,6 +320,7 @@ def _compute_ptes(
305320
alpha: float,
306321
variance_type: str,
307322
n_bootstrap: int,
323+
display_progress: bool = False,
308324
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
309325
"""Compute expected PTEs."""
310326
treatment_pdf, treatment_pdf_mat, _ = self._compute_interval_probability(
@@ -313,13 +329,15 @@ def _compute_ptes(
313329
self.covariates,
314330
self.treatment_arms,
315331
self.outcomes,
332+
display_progress=display_progress,
316333
)
317334
control_pdf, control_pdf_mat, _ = self._compute_interval_probability(
318335
control_treatment_arm,
319336
locations,
320337
self.covariates,
321338
self.treatment_arms,
322339
self.outcomes,
340+
display_progress=display_progress,
323341
)
324342

325343
pte = treatment_pdf - control_pdf
@@ -398,13 +416,16 @@ def find_quantile(quantile, arm):
398416

399417
return result
400418

401-
def predict(self, treatment_arm: int, locations: np.ndarray) -> np.ndarray:
419+
def predict(
420+
self, treatment_arm: int, locations: np.ndarray, display_progress: bool = True
421+
) -> np.ndarray:
402422
"""
403423
Compute cumulative distribution values.
404424
405425
Args:
406426
treatment_arm (int): The index of the treatment arm.
407427
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
428+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
408429
409430
Returns:
410431
np.ndarray: Estimated cumulative distribution values for the input.
@@ -425,6 +446,7 @@ def predict(self, treatment_arm: int, locations: np.ndarray) -> np.ndarray:
425446
self.covariates,
426447
self.treatment_arms,
427448
self.outcomes,
449+
display_progress=display_progress,
428450
)[0]
429451

430452
def _compute_cumulative_distribution(
@@ -434,6 +456,7 @@ def _compute_cumulative_distribution(
434456
covariates: np.ndarray,
435457
treatment_arms: np.ndarray,
436458
outcomes: np.array,
459+
display_progress: bool = False,
437460
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
438461
"""
439462
Compute the cumulative distribution values.
@@ -444,6 +467,7 @@ def _compute_cumulative_distribution(
444467
covariates (np.ndarray): An array of covariate variables in the observed data.
445468
treatment_arms (np.ndarray): An array of treatment arms in the observed data.
446469
outcomes (np.ndarray): An array of outcomes in the observed data.
470+
display_progress (bool): Whether to display a progress bar.
447471
448472
Returns:
449473
Tuple[np.ndarray, np.ndarray, np.ndarray]: Estimated cumulative distribution values, prediction for each observation, and superset prediction for each observation.

dte_adj/local.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def predict_ldte(
6161
control_treatment_arm: int,
6262
locations: np.ndarray,
6363
alpha: float = 0.05,
64+
display_progress: bool = True,
6465
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
6566
"""
6667
Compute Local Distribution Treatment Effects (LDTE).
@@ -74,6 +75,7 @@ def predict_ldte(
7475
control_treatment_arm (int): The index of the treatment arm of the control group.
7576
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
7677
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
78+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
7779
7880
Returns:
7981
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -117,6 +119,7 @@ def predict_ldte(
117119
control_treatment_arm,
118120
locations,
119121
alpha,
122+
display_progress,
120123
)
121124

122125
def predict_lpte(
@@ -125,6 +128,7 @@ def predict_lpte(
125128
control_treatment_arm: int,
126129
locations: np.ndarray,
127130
alpha: float = 0.05,
131+
display_progress: bool = True,
128132
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
129133
"""
130134
Compute Local Probability Treatment Effects (LPTE).
@@ -139,6 +143,7 @@ def predict_lpte(
139143
locations (np.ndarray): Scalar values defining interval boundaries for probability computation.
140144
For each interval (locations[i], locations[i+1]], the LPTE is computed.
141145
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
146+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
142147
143148
Returns:
144149
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -184,6 +189,7 @@ def predict_lpte(
184189
control_treatment_arm,
185190
locations,
186191
alpha,
192+
display_progress,
187193
)
188194

189195

@@ -230,6 +236,7 @@ def predict_ldte(
230236
control_treatment_arm: int,
231237
locations: np.ndarray,
232238
alpha: float = 0.05,
239+
display_progress: bool = True,
233240
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
234241
"""
235242
Compute Local Distribution Treatment Effects (LDTE) using ML adjustment.
@@ -242,6 +249,7 @@ def predict_ldte(
242249
control_treatment_arm (int): The index of the treatment arm of the control group.
243250
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
244251
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
252+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
245253
246254
Returns:
247255
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -287,6 +295,7 @@ def predict_ldte(
287295
control_treatment_arm,
288296
locations,
289297
alpha,
298+
display_progress,
290299
)
291300

292301
def predict_lpte(
@@ -295,6 +304,7 @@ def predict_lpte(
295304
control_treatment_arm: int,
296305
locations: np.ndarray,
297306
alpha: float = 0.05,
307+
display_progress: bool = True,
298308
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
299309
"""
300310
Compute Local Probability Treatment Effects (LPTE) using ML adjustment.
@@ -308,6 +318,7 @@ def predict_lpte(
308318
locations (np.ndarray): Scalar values defining interval boundaries for probability computation.
309319
For each interval (locations[i], locations[i+1]], the LPTE is computed.
310320
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
321+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
311322
312323
Returns:
313324
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
@@ -356,4 +367,5 @@ def predict_lpte(
356367
control_treatment_arm,
357368
locations,
358369
alpha,
370+
display_progress,
359371
)

dte_adj/stratified.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from typing import Tuple, Any
55
from copy import deepcopy
6+
from tqdm.auto import tqdm
67
from dte_adj.base import DistributionEstimatorBase
78
from dte_adj.util import ArrayLike, _convert_to_ndarray
89

@@ -54,6 +55,7 @@ def _compute_cumulative_distribution(
5455
covariates: np.ndarray,
5556
treatment_arms: np.ndarray,
5657
outcomes: np.array,
58+
display_progress: bool = False,
5759
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
5860
"""
5961
Compute the cumulative distribution values.
@@ -64,6 +66,7 @@ def _compute_cumulative_distribution(
6466
covariates (np.ndarray): An array of covariate variables in the observed data.
6567
treatment_arms (np.ndarray): An array of treatment arms in the observed data.
6668
outcomes (np.ndarray): An array of outcomes in the observed data.
69+
display_progress (bool): Whether to display a progress bar.
6770
6871
Returns:
6972
Tuple of numpy arrays:
@@ -102,6 +105,7 @@ def _compute_interval_probability(
102105
covariates: np.ndarray,
103106
treatment_arms: np.ndarray,
104107
outcomes: np.array,
108+
display_progress: bool = False,
105109
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
106110
"""Compute the interval probabilities.
107111
@@ -111,6 +115,7 @@ def _compute_interval_probability(
111115
covariates (np.ndarray): An array of covariate variables in the observed data.
112116
treatment_arms (np.ndarray): An array of treatment arms in the observed data.
113117
outcomes (np.ndarray): An array of outcomes in the observed data.
118+
display_progress (bool): Whether to display a progress bar.
114119
115120
Returns:
116121
Tuple of numpy arrays:
@@ -219,6 +224,7 @@ def _compute_cumulative_distribution(
219224
covariates: np.ndarray,
220225
treatment_arms: np.ndarray,
221226
outcomes: np.array,
227+
display_progress: bool = False,
222228
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
223229
"""
224230
Compute the cumulative distribution values.
@@ -229,6 +235,7 @@ def _compute_cumulative_distribution(
229235
covariates (np.ndarray): An array of covariate variables in the observed data.
230236
treatment_arms (np.ndarray): An array of treatment arms in the observed data.
231237
outcomes (np.ndarray): An array of outcomes in the observed data.
238+
display_progress (bool): Whether to display a progress bar.
232239
233240
Returns:
234241
Tuple of numpy arrays:
@@ -246,7 +253,10 @@ def _compute_cumulative_distribution(
246253
s_list = np.unique(strata)
247254
if self.is_multi_task:
248255
binomial = (outcomes.reshape(-1, 1) <= locations) * 1 # (n_records, n_loc)
249-
for fold in range(self.folds):
256+
fold_iter = range(self.folds)
257+
if display_progress:
258+
fold_iter = tqdm(fold_iter, desc="Cross-fitting (multi-task)")
259+
for fold in fold_iter:
250260
fold_mask = (folds != fold) & treatment_mask
251261
for s in s_list:
252262
s_mask = strata == s
@@ -270,7 +280,10 @@ def _compute_cumulative_distribution(
270280
)
271281
superset_prediction[superset_mask] = pred
272282
else:
273-
for i, location in enumerate(locations):
283+
loc_iter = enumerate(locations)
284+
if display_progress:
285+
loc_iter = tqdm(loc_iter, total=len(locations), desc="Computing CDF")
286+
for i, location in loc_iter:
274287
binomial = (outcomes <= location) * 1 # (n_records)
275288
for fold in range(self.folds):
276289
fold_mask = (folds != fold) & treatment_mask
@@ -322,6 +335,7 @@ def _compute_interval_probability(
322335
covariates: np.ndarray,
323336
treatment_arms: np.ndarray,
324337
outcomes: np.array,
338+
display_progress: bool = False,
325339
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
326340
"""
327341
Compute the interval probabilities.
@@ -332,6 +346,7 @@ def _compute_interval_probability(
332346
covariates (np.ndarray): An array of covariate variables in the observed data.
333347
treatment_arms (np.ndarray): An array of treatment arms in the observed data.
334348
outcomes (np.ndarray): An array of outcomes in the observed data.
349+
display_progress (bool): Whether to display a progress bar.
335350
336351
Returns:
337352
Tuple of numpy arrays:
@@ -348,7 +363,10 @@ def _compute_interval_probability(
348363
strata = self.strata
349364
s_list = np.unique(strata)
350365
binominals = (outcomes[:, np.newaxis] <= locations) * 1 # (n_records, n_loc)
351-
for i in range(len(locations) - 1):
366+
interval_iter = range(len(locations) - 1)
367+
if display_progress:
368+
interval_iter = tqdm(interval_iter, desc="Computing interval prob.")
369+
for i in interval_iter:
352370
binomial = binominals[:, i + 1] - binominals[:, i]
353371
for fold in range(self.folds):
354372
fold_mask = (folds != fold) & treatment_mask

0 commit comments

Comments
 (0)