|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -from typing import Tuple, Any |
| 4 | +from typing import Optional, Tuple, Any |
5 | 5 | from copy import deepcopy |
| 6 | +from scipy.stats import norm |
6 | 7 | from tqdm.auto import tqdm |
7 | 8 | from dte_adj.base import DistributionEstimatorBase |
8 | 9 | from dte_adj.util import ArrayLike, _convert_to_ndarray |
@@ -153,6 +154,77 @@ def _compute_interval_probability( |
153 | 154 | conditional_prediction[:, 1:] - conditional_prediction[:, :-1], |
154 | 155 | ) |
155 | 156 |
|
| 157 | + def predict_qte( |
| 158 | + self, |
| 159 | + target_treatment_arm: int, |
| 160 | + control_treatment_arm: int, |
| 161 | + quantiles: Optional[np.ndarray] = None, |
| 162 | + alpha: float = 0.05, |
| 163 | + n_bootstrap=500, |
| 164 | + display_progress: bool = True, |
| 165 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 166 | + """ |
| 167 | + Compute Quantile Treatment Effects (QTE) using stratified bootstrap. |
| 168 | +
|
| 169 | + Uses stratified bootstrap (resampling independently within each stratum) to |
| 170 | + correctly estimate variance under covariate adaptive randomization (CAR). |
| 171 | +
|
| 172 | + Args: |
| 173 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 174 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 175 | + quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. |
| 176 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 177 | + n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. |
| 178 | + display_progress (bool, optional): Whether to display a progress bar. Defaults to True. |
| 179 | +
|
| 180 | + Returns: |
| 181 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 182 | + - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile |
| 183 | + - Lower bounds (np.ndarray): Lower confidence interval bounds |
| 184 | + - Upper bounds (np.ndarray): Upper confidence interval bounds |
| 185 | + """ |
| 186 | + qte = self._compute_qtes( |
| 187 | + target_treatment_arm, |
| 188 | + control_treatment_arm, |
| 189 | + quantiles, |
| 190 | + self.covariates, |
| 191 | + self.treatment_arms, |
| 192 | + self.outcomes, |
| 193 | + self.strata, |
| 194 | + ) |
| 195 | + |
| 196 | + # Precompute stratum indices for stratified bootstrap |
| 197 | + unique_strata = np.unique(self.strata) |
| 198 | + strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} |
| 199 | + |
| 200 | + qtes = np.zeros((n_bootstrap, qte.shape[0])) |
| 201 | + bootstrap_iter = range(n_bootstrap) |
| 202 | + if display_progress: |
| 203 | + bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") |
| 204 | + for b in bootstrap_iter: |
| 205 | + # Stratified bootstrap: resample within each stratum independently |
| 206 | + bootstrap_indexes = np.concatenate([ |
| 207 | + np.random.choice(idx, size=len(idx), replace=True) |
| 208 | + for idx in strata_indices.values() |
| 209 | + ]) |
| 210 | + |
| 211 | + qtes[b] = self._compute_qtes( |
| 212 | + target_treatment_arm, |
| 213 | + control_treatment_arm, |
| 214 | + quantiles, |
| 215 | + self.covariates[bootstrap_indexes], |
| 216 | + self.treatment_arms[bootstrap_indexes], |
| 217 | + self.outcomes[bootstrap_indexes], |
| 218 | + self.strata[bootstrap_indexes], |
| 219 | + ) |
| 220 | + |
| 221 | + qte_var = qtes.var(axis=0) |
| 222 | + |
| 223 | + qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) |
| 224 | + qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) |
| 225 | + |
| 226 | + return qte, qte_lower, qte_upper |
| 227 | + |
156 | 228 |
|
157 | 229 | class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase): |
158 | 230 | """A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR.""" |
@@ -405,6 +477,77 @@ def _compute_interval_probability( |
405 | 477 |
|
406 | 478 | return prediction.mean(axis=0), prediction, superset_prediction |
407 | 479 |
|
| 480 | + def predict_qte( |
| 481 | + self, |
| 482 | + target_treatment_arm: int, |
| 483 | + control_treatment_arm: int, |
| 484 | + quantiles: Optional[np.ndarray] = None, |
| 485 | + alpha: float = 0.05, |
| 486 | + n_bootstrap=500, |
| 487 | + display_progress: bool = True, |
| 488 | + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 489 | + """ |
| 490 | + Compute Quantile Treatment Effects (QTE) using stratified bootstrap. |
| 491 | +
|
| 492 | + Uses stratified bootstrap (resampling independently within each stratum) to |
| 493 | + correctly estimate variance under covariate adaptive randomization (CAR). |
| 494 | +
|
| 495 | + Args: |
| 496 | + target_treatment_arm (int): The index of the treatment arm of the treatment group. |
| 497 | + control_treatment_arm (int): The index of the treatment arm of the control group. |
| 498 | + quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. |
| 499 | + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
| 500 | + n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. |
| 501 | + display_progress (bool, optional): Whether to display a progress bar. Defaults to True. |
| 502 | +
|
| 503 | + Returns: |
| 504 | + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
| 505 | + - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile |
| 506 | + - Lower bounds (np.ndarray): Lower confidence interval bounds |
| 507 | + - Upper bounds (np.ndarray): Upper confidence interval bounds |
| 508 | + """ |
| 509 | + qte = self._compute_qtes( |
| 510 | + target_treatment_arm, |
| 511 | + control_treatment_arm, |
| 512 | + quantiles, |
| 513 | + self.covariates, |
| 514 | + self.treatment_arms, |
| 515 | + self.outcomes, |
| 516 | + self.strata, |
| 517 | + ) |
| 518 | + |
| 519 | + # Precompute stratum indices for stratified bootstrap |
| 520 | + unique_strata = np.unique(self.strata) |
| 521 | + strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} |
| 522 | + |
| 523 | + qtes = np.zeros((n_bootstrap, qte.shape[0])) |
| 524 | + bootstrap_iter = range(n_bootstrap) |
| 525 | + if display_progress: |
| 526 | + bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") |
| 527 | + for b in bootstrap_iter: |
| 528 | + # Stratified bootstrap: resample within each stratum independently |
| 529 | + bootstrap_indexes = np.concatenate([ |
| 530 | + np.random.choice(idx, size=len(idx), replace=True) |
| 531 | + for idx in strata_indices.values() |
| 532 | + ]) |
| 533 | + |
| 534 | + qtes[b] = self._compute_qtes( |
| 535 | + target_treatment_arm, |
| 536 | + control_treatment_arm, |
| 537 | + quantiles, |
| 538 | + self.covariates[bootstrap_indexes], |
| 539 | + self.treatment_arms[bootstrap_indexes], |
| 540 | + self.outcomes[bootstrap_indexes], |
| 541 | + self.strata[bootstrap_indexes], |
| 542 | + ) |
| 543 | + |
| 544 | + qte_var = qtes.var(axis=0) |
| 545 | + |
| 546 | + qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) |
| 547 | + qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) |
| 548 | + |
| 549 | + return qte, qte_lower, qte_upper |
| 550 | + |
408 | 551 | def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray: |
409 | 552 | if hasattr(model, "predict_proba"): |
410 | 553 | if self.is_multi_task: |
|
0 commit comments