|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -from typing import Optional, Tuple, Any |
| 4 | +from typing import Tuple, Any |
5 | 5 | from copy import deepcopy |
6 | | -from scipy.stats import norm |
7 | 6 | from tqdm.auto import tqdm |
8 | 7 | from dte_adj.base import DistributionEstimatorBase |
9 | 8 | from dte_adj.util import ArrayLike, _convert_to_ndarray |
@@ -154,82 +153,6 @@ def _compute_interval_probability( |
154 | 153 | conditional_prediction[:, 1:] - conditional_prediction[:, :-1], |
155 | 154 | ) |
156 | 155 |
|
157 | | - def predict_qte( |
158 | | - self, |
159 | | - target_treatment_arm: int, |
160 | | - control_treatment_arm: int, |
161 | | - quantiles: Optional[np.ndarray] = None, |
162 | | - alpha: float = 0.05, |
163 | | - n_bootstrap=500, |
164 | | - display_progress: bool = True, |
165 | | - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
166 | | - """ |
167 | | - Compute Quantile Treatment Effects (QTE) using stratified bootstrap. |
168 | | -
|
169 | | - Uses stratified bootstrap (resampling independently within each stratum) to |
170 | | - correctly estimate variance under covariate adaptive randomization (CAR). |
171 | | -
|
172 | | - Args: |
173 | | - target_treatment_arm (int): The index of the treatment arm of the treatment group. |
174 | | - control_treatment_arm (int): The index of the treatment arm of the control group. |
175 | | - quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. |
176 | | - alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
177 | | - n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. |
178 | | - display_progress (bool, optional): Whether to display a progress bar. Defaults to True. |
179 | | -
|
180 | | - Returns: |
181 | | - Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
182 | | - - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile |
183 | | - - Lower bounds (np.ndarray): Lower confidence interval bounds |
184 | | - - Upper bounds (np.ndarray): Upper confidence interval bounds |
185 | | - """ |
186 | | - if quantiles is None: |
187 | | - quantiles = np.arange(1, 10) / 10 |
188 | | - if np.any((quantiles <= 0) | (quantiles >= 1)): |
189 | | - raise ValueError("quantiles must be in the open interval (0, 1)") |
190 | | - |
191 | | - qte = self._compute_qtes( |
192 | | - target_treatment_arm, |
193 | | - control_treatment_arm, |
194 | | - quantiles, |
195 | | - self.covariates, |
196 | | - self.treatment_arms, |
197 | | - self.outcomes, |
198 | | - self.strata, |
199 | | - ) |
200 | | - |
201 | | - # Precompute stratum indices for stratified bootstrap |
202 | | - unique_strata = np.unique(self.strata) |
203 | | - strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} |
204 | | - |
205 | | - qtes = np.zeros((n_bootstrap, qte.shape[0])) |
206 | | - bootstrap_iter = range(n_bootstrap) |
207 | | - if display_progress: |
208 | | - bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") |
209 | | - for b in bootstrap_iter: |
210 | | - # Stratified bootstrap: resample within each stratum independently |
211 | | - bootstrap_indexes = np.concatenate([ |
212 | | - np.random.choice(idx, size=len(idx), replace=True) |
213 | | - for idx in strata_indices.values() |
214 | | - ]) |
215 | | - |
216 | | - qtes[b] = self._compute_qtes( |
217 | | - target_treatment_arm, |
218 | | - control_treatment_arm, |
219 | | - quantiles, |
220 | | - self.covariates[bootstrap_indexes], |
221 | | - self.treatment_arms[bootstrap_indexes], |
222 | | - self.outcomes[bootstrap_indexes], |
223 | | - self.strata[bootstrap_indexes], |
224 | | - ) |
225 | | - |
226 | | - qte_var = qtes.var(axis=0) |
227 | | - |
228 | | - qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) |
229 | | - qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) |
230 | | - |
231 | | - return qte, qte_lower, qte_upper |
232 | | - |
233 | 156 |
|
234 | 157 | class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase): |
235 | 158 | """A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR.""" |
@@ -482,82 +405,6 @@ def _compute_interval_probability( |
482 | 405 |
|
483 | 406 | return prediction.mean(axis=0), prediction, superset_prediction |
484 | 407 |
|
485 | | - def predict_qte( |
486 | | - self, |
487 | | - target_treatment_arm: int, |
488 | | - control_treatment_arm: int, |
489 | | - quantiles: Optional[np.ndarray] = None, |
490 | | - alpha: float = 0.05, |
491 | | - n_bootstrap=500, |
492 | | - display_progress: bool = True, |
493 | | - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
494 | | - """ |
495 | | - Compute Quantile Treatment Effects (QTE) using stratified bootstrap. |
496 | | -
|
497 | | - Uses stratified bootstrap (resampling independently within each stratum) to |
498 | | - correctly estimate variance under covariate adaptive randomization (CAR). |
499 | | -
|
500 | | - Args: |
501 | | - target_treatment_arm (int): The index of the treatment arm of the treatment group. |
502 | | - control_treatment_arm (int): The index of the treatment arm of the control group. |
503 | | - quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. |
504 | | - alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. |
505 | | - n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. |
506 | | - display_progress (bool, optional): Whether to display a progress bar. Defaults to True. |
507 | | -
|
508 | | - Returns: |
509 | | - Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: |
510 | | - - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile |
511 | | - - Lower bounds (np.ndarray): Lower confidence interval bounds |
512 | | - - Upper bounds (np.ndarray): Upper confidence interval bounds |
513 | | - """ |
514 | | - if quantiles is None: |
515 | | - quantiles = np.arange(1, 10) / 10 |
516 | | - if np.any((quantiles <= 0) | (quantiles >= 1)): |
517 | | - raise ValueError("quantiles must be in the open interval (0, 1)") |
518 | | - |
519 | | - qte = self._compute_qtes( |
520 | | - target_treatment_arm, |
521 | | - control_treatment_arm, |
522 | | - quantiles, |
523 | | - self.covariates, |
524 | | - self.treatment_arms, |
525 | | - self.outcomes, |
526 | | - self.strata, |
527 | | - ) |
528 | | - |
529 | | - # Precompute stratum indices for stratified bootstrap |
530 | | - unique_strata = np.unique(self.strata) |
531 | | - strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} |
532 | | - |
533 | | - qtes = np.zeros((n_bootstrap, qte.shape[0])) |
534 | | - bootstrap_iter = range(n_bootstrap) |
535 | | - if display_progress: |
536 | | - bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") |
537 | | - for b in bootstrap_iter: |
538 | | - # Stratified bootstrap: resample within each stratum independently |
539 | | - bootstrap_indexes = np.concatenate([ |
540 | | - np.random.choice(idx, size=len(idx), replace=True) |
541 | | - for idx in strata_indices.values() |
542 | | - ]) |
543 | | - |
544 | | - qtes[b] = self._compute_qtes( |
545 | | - target_treatment_arm, |
546 | | - control_treatment_arm, |
547 | | - quantiles, |
548 | | - self.covariates[bootstrap_indexes], |
549 | | - self.treatment_arms[bootstrap_indexes], |
550 | | - self.outcomes[bootstrap_indexes], |
551 | | - self.strata[bootstrap_indexes], |
552 | | - ) |
553 | | - |
554 | | - qte_var = qtes.var(axis=0) |
555 | | - |
556 | | - qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) |
557 | | - qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) |
558 | | - |
559 | | - return qte, qte_lower, qte_upper |
560 | | - |
561 | 408 | def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray: |
562 | 409 | if hasattr(model, "predict_proba"): |
563 | 410 | if self.is_multi_task: |
|
0 commit comments