Skip to content

Commit 4bd2594

Browse files
committed
add QTE support for covariate adaptive randomization
Implement predict_qte in SimpleStratifiedDistributionEstimator and AdjustedStratifiedDistributionEstimator with stratified bootstrap (resampling within each stratum independently) to correctly estimate variance under CAR designs.
1 parent a4c6eff commit 4bd2594

1 file changed

Lines changed: 144 additions & 1 deletion

File tree

dte_adj/stratified.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import numpy as np
4-
from typing import Tuple, Any
4+
from typing import Optional, Tuple, Any
55
from copy import deepcopy
6+
from scipy.stats import norm
67
from tqdm.auto import tqdm
78
from dte_adj.base import DistributionEstimatorBase
89
from dte_adj.util import ArrayLike, _convert_to_ndarray
@@ -153,6 +154,77 @@ def _compute_interval_probability(
153154
conditional_prediction[:, 1:] - conditional_prediction[:, :-1],
154155
)
155156

157+
def predict_qte(
158+
self,
159+
target_treatment_arm: int,
160+
control_treatment_arm: int,
161+
quantiles: Optional[np.ndarray] = None,
162+
alpha: float = 0.05,
163+
n_bootstrap=500,
164+
display_progress: bool = True,
165+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
166+
"""
167+
Compute Quantile Treatment Effects (QTE) using stratified bootstrap.
168+
169+
Uses stratified bootstrap (resampling independently within each stratum) to
170+
correctly estimate variance under covariate adaptive randomization (CAR).
171+
172+
Args:
173+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
174+
control_treatment_arm (int): The index of the treatment arm of the control group.
175+
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
176+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
177+
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
178+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
179+
180+
Returns:
181+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
182+
- Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
183+
- Lower bounds (np.ndarray): Lower confidence interval bounds
184+
- Upper bounds (np.ndarray): Upper confidence interval bounds
185+
"""
186+
qte = self._compute_qtes(
187+
target_treatment_arm,
188+
control_treatment_arm,
189+
quantiles,
190+
self.covariates,
191+
self.treatment_arms,
192+
self.outcomes,
193+
self.strata,
194+
)
195+
196+
# Precompute stratum indices for stratified bootstrap
197+
unique_strata = np.unique(self.strata)
198+
strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata}
199+
200+
qtes = np.zeros((n_bootstrap, qte.shape[0]))
201+
bootstrap_iter = range(n_bootstrap)
202+
if display_progress:
203+
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
204+
for b in bootstrap_iter:
205+
# Stratified bootstrap: resample within each stratum independently
206+
bootstrap_indexes = np.concatenate([
207+
np.random.choice(idx, size=len(idx), replace=True)
208+
for idx in strata_indices.values()
209+
])
210+
211+
qtes[b] = self._compute_qtes(
212+
target_treatment_arm,
213+
control_treatment_arm,
214+
quantiles,
215+
self.covariates[bootstrap_indexes],
216+
self.treatment_arms[bootstrap_indexes],
217+
self.outcomes[bootstrap_indexes],
218+
self.strata[bootstrap_indexes],
219+
)
220+
221+
qte_var = qtes.var(axis=0)
222+
223+
qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var)
224+
qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var)
225+
226+
return qte, qte_lower, qte_upper
227+
156228

157229
class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase):
158230
"""A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR."""
@@ -405,6 +477,77 @@ def _compute_interval_probability(
405477

406478
return prediction.mean(axis=0), prediction, superset_prediction
407479

480+
def predict_qte(
481+
self,
482+
target_treatment_arm: int,
483+
control_treatment_arm: int,
484+
quantiles: Optional[np.ndarray] = None,
485+
alpha: float = 0.05,
486+
n_bootstrap=500,
487+
display_progress: bool = True,
488+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
489+
"""
490+
Compute Quantile Treatment Effects (QTE) using stratified bootstrap.
491+
492+
Uses stratified bootstrap (resampling independently within each stratum) to
493+
correctly estimate variance under covariate adaptive randomization (CAR).
494+
495+
Args:
496+
target_treatment_arm (int): The index of the treatment arm of the treatment group.
497+
control_treatment_arm (int): The index of the treatment arm of the control group.
498+
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
499+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
500+
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
501+
display_progress (bool, optional): Whether to display a progress bar. Defaults to True.
502+
503+
Returns:
504+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
505+
- Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
506+
- Lower bounds (np.ndarray): Lower confidence interval bounds
507+
- Upper bounds (np.ndarray): Upper confidence interval bounds
508+
"""
509+
qte = self._compute_qtes(
510+
target_treatment_arm,
511+
control_treatment_arm,
512+
quantiles,
513+
self.covariates,
514+
self.treatment_arms,
515+
self.outcomes,
516+
self.strata,
517+
)
518+
519+
# Precompute stratum indices for stratified bootstrap
520+
unique_strata = np.unique(self.strata)
521+
strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata}
522+
523+
qtes = np.zeros((n_bootstrap, qte.shape[0]))
524+
bootstrap_iter = range(n_bootstrap)
525+
if display_progress:
526+
bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
527+
for b in bootstrap_iter:
528+
# Stratified bootstrap: resample within each stratum independently
529+
bootstrap_indexes = np.concatenate([
530+
np.random.choice(idx, size=len(idx), replace=True)
531+
for idx in strata_indices.values()
532+
])
533+
534+
qtes[b] = self._compute_qtes(
535+
target_treatment_arm,
536+
control_treatment_arm,
537+
quantiles,
538+
self.covariates[bootstrap_indexes],
539+
self.treatment_arms[bootstrap_indexes],
540+
self.outcomes[bootstrap_indexes],
541+
self.strata[bootstrap_indexes],
542+
)
543+
544+
qte_var = qtes.var(axis=0)
545+
546+
qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var)
547+
qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var)
548+
549+
return qte, qte_lower, qte_upper
550+
408551
def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray:
409552
if hasattr(model, "predict_proba"):
410553
if self.is_multi_task:

0 commit comments

Comments
 (0)