@@ -209,20 +209,52 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
209209 < span class ="n "> n_bootstrap</ span > < span class ="o "> =</ span > < span class ="mi "> 500</ span > < span class ="p "> ,</ span >
210210 < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> Tuple</ span > < span class ="p "> [</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ]:</ span >
211211< span class ="w "> </ span > < span class ="sd "> """</ span >
212- < span class ="sd "> Compute QTE based on the estimator for the distribution function.</ span >
212+ < span class ="sd "> Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.</ span >
213+
214+ < span class ="sd "> The QTE measures the difference in quantiles between treatment groups, providing insights</ span >
215+ < span class ="sd "> into how treatment affects different parts of the outcome distribution. For stratified</ span >
216+ < span class ="sd "> estimators, the computation properly accounts for strata.</ span >
213217
214218< span class ="sd "> Args:</ span >
215219< span class ="sd "> target_treatment_arm (int): The index of the treatment arm of the treatment group.</ span >
216220< span class ="sd "> control_treatment_arm (int): The index of the treatment arm of the control group.</ span >
217- < span class ="sd "> quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10) ].</ span >
221+ < span class ="sd "> quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9 ].</ span >
218222< span class ="sd "> alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.</ span >
219223< span class ="sd "> n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.</ span >
220224
221225< span class ="sd "> Returns:</ span >
222226< span class ="sd "> Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:</ span >
223- < span class ="sd "> - Expected QTEs</ span >
224- < span class ="sd "> - Upper bounds</ span >
225- < span class ="sd "> - Lower bounds</ span >
227+ < span class ="sd "> - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile</ span >
228+ < span class ="sd "> - Lower bounds (np.ndarray): Lower confidence interval bounds</ span >
229+ < span class ="sd "> - Upper bounds (np.ndarray): Upper confidence interval bounds</ span >
230+
231+ < span class ="sd "> Example:</ span >
232+ < span class ="sd "> .. code-block:: python</ span >
233+
234+ < span class ="sd "> import numpy as np</ span >
235+ < span class ="sd "> from dte_adj import SimpleStratifiedDistributionEstimator</ span >
236+ < span class ="sd "> </ span >
237+ < span class ="sd "> # Generate stratified sample data</ span >
238+ < span class ="sd "> X = np.random.randn(1000, 5)</ span >
239+ < span class ="sd "> strata = np.random.choice([0, 1, 2], size=1000)</ span >
240+ < span class ="sd "> D = np.random.binomial(1, 0.5, 1000)</ span >
241+ < span class ="sd "> Y = X[:, 0] + 2 * D + 0.5 * strata + np.random.randn(1000)</ span >
242+ < span class ="sd "> </ span >
243+ < span class ="sd "> # Fit stratified estimator</ span >
244+ < span class ="sd "> estimator = SimpleStratifiedDistributionEstimator()</ span >
245+ < span class ="sd "> estimator.fit(X, D, Y, strata)</ span >
246+ < span class ="sd "> </ span >
247+ < span class ="sd "> # Compute QTE at specific quantiles</ span >
248+ < span class ="sd "> quantiles = np.array([0.25, 0.5, 0.75]) # 25th, 50th, 75th percentiles</ span >
249+ < span class ="sd "> qte, lower, upper = estimator.predict_qte(</ span >
250+ < span class ="sd "> target_treatment_arm=1,</ span >
251+ < span class ="sd "> control_treatment_arm=0,</ span >
252+ < span class ="sd "> quantiles=quantiles,</ span >
253+ < span class ="sd "> n_bootstrap=100</ span >
254+ < span class ="sd "> )</ span >
255+ < span class ="sd "> </ span >
256+ < span class ="sd "> print(f"QTE at quantiles {quantiles}: {qte}")</ span >
257+ < span class ="sd "> print(f"Median effect (50th percentile): {qte[1]:.3f}")</ span >
226258< span class ="sd "> """</ span >
227259 < span class ="n "> qte</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_qtes</ span > < span class ="p "> (</ span >
228260 < span class ="n "> target_treatment_arm</ span > < span class ="p "> ,</ span >
@@ -231,20 +263,23 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
231263 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> covariates</ span > < span class ="p "> ,</ span >
232264 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> treatment_arms</ span > < span class ="p "> ,</ span >
233265 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> outcomes</ span > < span class ="p "> ,</ span >
266+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> strata</ span > < span class ="p "> ,</ span >
234267 < span class ="p "> )</ span >
235268 < span class ="n "> n_obs</ span > < span class ="o "> =</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> outcomes</ span > < span class ="p "> )</ span >
236269 < span class ="n "> indexes</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> arange</ span > < span class ="p "> (</ span > < span class ="n "> n_obs</ span > < span class ="p "> )</ span >
237270
238271 < span class ="n "> qtes</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="n "> n_bootstrap</ span > < span class ="p "> ,</ span > < span class ="n "> qte</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]))</ span >
239272 < span class ="k "> for</ span > < span class ="n "> b</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> n_bootstrap</ span > < span class ="p "> ):</ span >
240273 < span class ="n "> bootstrap_indexes</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> random</ span > < span class ="o "> .</ span > < span class ="n "> choice</ span > < span class ="p "> (</ span > < span class ="n "> indexes</ span > < span class ="p "> ,</ span > < span class ="n "> size</ span > < span class ="o "> =</ span > < span class ="n "> n_obs</ span > < span class ="p "> ,</ span > < span class ="n "> replace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
274+
241275 < span class ="n "> qtes</ span > < span class ="p "> [</ span > < span class ="n "> b</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_qtes</ span > < span class ="p "> (</ span >
242276 < span class ="n "> target_treatment_arm</ span > < span class ="p "> ,</ span >
243277 < span class ="n "> control_treatment_arm</ span > < span class ="p "> ,</ span >
244278 < span class ="n "> quantiles</ span > < span class ="p "> ,</ span >
245279 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> bootstrap_indexes</ span > < span class ="p "> ],</ span >
246280 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> treatment_arms</ span > < span class ="p "> [</ span > < span class ="n "> bootstrap_indexes</ span > < span class ="p "> ],</ span >
247281 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> outcomes</ span > < span class ="p "> [</ span > < span class ="n "> bootstrap_indexes</ span > < span class ="p "> ],</ span >
282+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> strata</ span > < span class ="p "> [</ span > < span class ="n "> bootstrap_indexes</ span > < span class ="p "> ],</ span >
248283 < span class ="p "> )</ span >
249284
250285 < span class ="n "> qte_var</ span > < span class ="o "> =</ span > < span class ="n "> qtes</ span > < span class ="o "> .</ span > < span class ="n "> var</ span > < span class ="p "> (</ span > < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span > < span class ="p "> )</ span >
@@ -366,7 +401,8 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
366401 < span class ="n "> covariates</ span > < span class ="p "> :</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span >
367402 < span class ="n "> treatment_arms</ span > < span class ="p "> :</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span >
368403 < span class ="n "> outcomes</ span > < span class ="p "> :</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> array</ span > < span class ="p "> ,</ span >
369- < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> Tuple</ span > < span class ="p "> [</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ]:</ span >
404+ < span class ="n "> strata</ span > < span class ="p "> :</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span >
405+ < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> :</ span >
370406< span class ="w "> </ span > < span class ="sd "> """Compute expected QTEs."""</ span >
371407 < span class ="n "> locations</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> sort</ span > < span class ="p "> (</ span > < span class ="n "> outcomes</ span > < span class ="p "> )</ span >
372408
@@ -375,13 +411,21 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
375411 < span class ="n "> result</ span > < span class ="o "> =</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span >
376412 < span class ="k "> while</ span > < span class ="n "> low</ span > < span class ="o "> <=</ span > < span class ="n "> high</ span > < span class ="p "> :</ span >
377413 < span class ="n "> mid</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> low</ span > < span class ="o "> +</ span > < span class ="n "> high</ span > < span class ="p "> )</ span > < span class ="o "> //</ span > < span class ="mi "> 2</ span >
414+ < span class ="c1 "> # Temporarily store original strata and use the provided strata</ span >
415+ < span class ="n "> original_strata</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> strata</ span >
416+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> strata</ span > < span class ="o "> =</ span > < span class ="n "> strata</ span >
417+
378418 < span class ="n "> val</ span > < span class ="p "> ,</ span > < span class ="n "> _</ span > < span class ="p "> ,</ span > < span class ="n "> _</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_cumulative_distribution</ span > < span class ="p "> (</ span >
379419 < span class ="n "> arm</ span > < span class ="p "> ,</ span >
380420 < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> full</ span > < span class ="p "> ((</ span > < span class ="mi "> 1</ span > < span class ="p "> ),</ span > < span class ="n "> locations</ span > < span class ="p "> [</ span > < span class ="n "> mid</ span > < span class ="p "> ]),</ span >
381421 < span class ="n "> covariates</ span > < span class ="p "> ,</ span >
382422 < span class ="n "> treatment_arms</ span > < span class ="p "> ,</ span >
383423 < span class ="n "> outcomes</ span > < span class ="p "> ,</ span >
384424 < span class ="p "> )</ span >
425+
426+ < span class ="c1 "> # Restore original strata</ span >
427+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> strata</ span > < span class ="o "> =</ span > < span class ="n "> original_strata</ span >
428+
385429 < span class ="k "> if</ span > < span class ="n "> val</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="o "> <=</ span > < span class ="n "> quantile</ span > < span class ="p "> :</ span >
386430 < span class ="n "> result</ span > < span class ="o "> =</ span > < span class ="n "> locations</ span > < span class ="p "> [</ span > < span class ="n "> mid</ span > < span class ="p "> ]</ span >
387431 < span class ="n "> low</ span > < span class ="o "> =</ span > < span class ="n "> mid</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span >
0 commit comments