Skip to content

Commit fa89916

Browse files
committed
refactor: consolidate predict_qte in DistributionEstimatorBase
Move the bootstrap loop into the base class and switch it to stratified resampling (per-stratum np.random.choice). Stratified resampling on a single stratum is equivalent to plain bootstrap, so SimpleDistributionEstimator and AdjustedDistributionEstimator (which set strata to a constant) remain unchanged in behavior while the CAR-aware variants pick up the correct variance estimator without any override. This removes the duplicated predict_qte bodies in both stratified subclasses, leaving the only delta vs. the base implementation in the resampling step.
1 parent f512d91 commit fa89916

2 files changed

Lines changed: 17 additions & 157 deletions

File tree

dte_adj/base.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ def predict_qte(
180180
into how treatment affects different parts of the outcome distribution. For stratified
181181
estimators, the computation properly accounts for strata.
182182
183+
Variance is estimated by stratified bootstrap: indices are resampled with replacement
184+
within each stratum independently, which preserves per-stratum sample sizes and reflects
185+
the covariate-adaptive randomization (CAR) design. For estimators without strata
186+
(single stratum), this degenerates to a plain bootstrap.
187+
183188
Args:
184189
target_treatment_arm (int): The index of the treatment arm of the treatment group.
185190
control_treatment_arm (int): The index of the treatment arm of the control group.
@@ -236,15 +241,23 @@ def predict_qte(
236241
self.outcomes,
237242
self.strata,
238243
)
239-
n_obs = len(self.outcomes)
240-
indexes = np.arange(n_obs)
244+
245+
# Precompute stratum indices for stratified bootstrap.
246+
# When there is a single stratum this is equivalent to plain bootstrap.
247+
unique_strata = np.unique(self.strata)
248+
strata_indices = [np.where(self.strata == s)[0] for s in unique_strata]
241249

242250
qtes = np.zeros((n_bootstrap, qte.shape[0]))
243251
bootstrap_iter = range(n_bootstrap)
244252
if display_progress:
245253
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
246254
for b in bootstrap_iter:
247-
bootstrap_indexes = np.random.choice(indexes, size=n_obs, replace=True)
255+
bootstrap_indexes = np.concatenate(
256+
[
257+
np.random.choice(idx, size=len(idx), replace=True)
258+
for idx in strata_indices
259+
]
260+
)
248261

249262
qtes[b] = self._compute_qtes(
250263
target_treatment_arm,

dte_adj/stratified.py

Lines changed: 1 addition & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from __future__ import annotations
22

33
import numpy as np
4-
from typing import Optional, Tuple, Any
4+
from typing import Tuple, Any
55
from copy import deepcopy
6-
from scipy.stats import norm
76
from tqdm.auto import tqdm
87
from dte_adj.base import DistributionEstimatorBase
98
from dte_adj.util import ArrayLike, _convert_to_ndarray
@@ -154,82 +153,6 @@ def _compute_interval_probability(
154153
conditional_prediction[:, 1:] - conditional_prediction[:, :-1],
155154
)
156155

157-
def predict_qte(
158-
self,
159-
target_treatment_arm: int,
160-
control_treatment_arm: int,
161-
quantiles: Optional[np.ndarray] = None,
162-
alpha: float = 0.05,
163-
n_bootstrap=500,
164-
display_progress: bool = True,
165-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
166-
"""
167-
Compute Quantile Treatment Effects (QTE) using stratified bootstrap.
168-
169-
Uses stratified bootstrap (resampling independently within each stratum) to
170-
correctly estimate variance under covariate adaptive randomization (CAR).
171-
172-
Args:
173-
target_treatment_arm (int): The index of the treatment arm of the treatment group.
174-
control_treatment_arm (int): The index of the treatment arm of the control group.
175-
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
176-
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
177-
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
178-
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
179-
180-
Returns:
181-
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
182-
- Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
183-
- Lower bounds (np.ndarray): Lower confidence interval bounds
184-
- Upper bounds (np.ndarray): Upper confidence interval bounds
185-
"""
186-
if quantiles is None:
187-
quantiles = np.arange(1, 10) / 10
188-
if np.any((quantiles <= 0) | (quantiles >= 1)):
189-
raise ValueError("quantiles must be in the open interval (0, 1)")
190-
191-
qte = self._compute_qtes(
192-
target_treatment_arm,
193-
control_treatment_arm,
194-
quantiles,
195-
self.covariates,
196-
self.treatment_arms,
197-
self.outcomes,
198-
self.strata,
199-
)
200-
201-
# Precompute stratum indices for stratified bootstrap
202-
unique_strata = np.unique(self.strata)
203-
strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata}
204-
205-
qtes = np.zeros((n_bootstrap, qte.shape[0]))
206-
bootstrap_iter = range(n_bootstrap)
207-
if display_progress:
208-
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
209-
for b in bootstrap_iter:
210-
# Stratified bootstrap: resample within each stratum independently
211-
bootstrap_indexes = np.concatenate([
212-
np.random.choice(idx, size=len(idx), replace=True)
213-
for idx in strata_indices.values()
214-
])
215-
216-
qtes[b] = self._compute_qtes(
217-
target_treatment_arm,
218-
control_treatment_arm,
219-
quantiles,
220-
self.covariates[bootstrap_indexes],
221-
self.treatment_arms[bootstrap_indexes],
222-
self.outcomes[bootstrap_indexes],
223-
self.strata[bootstrap_indexes],
224-
)
225-
226-
qte_var = qtes.var(axis=0)
227-
228-
qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var)
229-
qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var)
230-
231-
return qte, qte_lower, qte_upper
232-
233156

234157
class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase):
235158
"""A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR."""
@@ -482,82 +405,6 @@ def _compute_interval_probability(
482405

483406
return prediction.mean(axis=0), prediction, superset_prediction
484407

485-
def predict_qte(
486-
self,
487-
target_treatment_arm: int,
488-
control_treatment_arm: int,
489-
quantiles: Optional[np.ndarray] = None,
490-
alpha: float = 0.05,
491-
n_bootstrap=500,
492-
display_progress: bool = True,
493-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
494-
"""
495-
Compute Quantile Treatment Effects (QTE) using stratified bootstrap.
496-
497-
Uses stratified bootstrap (resampling independently within each stratum) to
498-
correctly estimate variance under covariate adaptive randomization (CAR).
499-
500-
Args:
501-
target_treatment_arm (int): The index of the treatment arm of the treatment group.
502-
control_treatment_arm (int): The index of the treatment arm of the control group.
503-
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
504-
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
505-
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
506-
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
507-
508-
Returns:
509-
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
510-
- Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
511-
- Lower bounds (np.ndarray): Lower confidence interval bounds
512-
- Upper bounds (np.ndarray): Upper confidence interval bounds
513-
"""
514-
if quantiles is None:
515-
quantiles = np.arange(1, 10) / 10
516-
if np.any((quantiles <= 0) | (quantiles >= 1)):
517-
raise ValueError("quantiles must be in the open interval (0, 1)")
518-
519-
qte = self._compute_qtes(
520-
target_treatment_arm,
521-
control_treatment_arm,
522-
quantiles,
523-
self.covariates,
524-
self.treatment_arms,
525-
self.outcomes,
526-
self.strata,
527-
)
528-
529-
# Precompute stratum indices for stratified bootstrap
530-
unique_strata = np.unique(self.strata)
531-
strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata}
532-
533-
qtes = np.zeros((n_bootstrap, qte.shape[0]))
534-
bootstrap_iter = range(n_bootstrap)
535-
if display_progress:
536-
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
537-
for b in bootstrap_iter:
538-
# Stratified bootstrap: resample within each stratum independently
539-
bootstrap_indexes = np.concatenate([
540-
np.random.choice(idx, size=len(idx), replace=True)
541-
for idx in strata_indices.values()
542-
])
543-
544-
qtes[b] = self._compute_qtes(
545-
target_treatment_arm,
546-
control_treatment_arm,
547-
quantiles,
548-
self.covariates[bootstrap_indexes],
549-
self.treatment_arms[bootstrap_indexes],
550-
self.outcomes[bootstrap_indexes],
551-
self.strata[bootstrap_indexes],
552-
)
553-
554-
qte_var = qtes.var(axis=0)
555-
556-
qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var)
557-
qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var)
558-
559-
return qte, qte_lower, qte_upper
560-
561408
def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray:
562409
if hasattr(model, "predict_proba"):
563410
if self.is_multi_task:

0 commit comments

Comments
 (0)