@@ -176,20 +176,52 @@ def predict_qte(
176176 n_bootstrap = 500 ,
177177 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
178178 """
179- Compute QTE based on the estimator for the distribution function.
179+ Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.
180+
181+ The QTE measures the difference in quantiles between treatment groups, providing insights
182+ into how treatment affects different parts of the outcome distribution. For stratified
183+ estimators, the computation properly accounts for strata.
180184
181185 Args:
182186 target_treatment_arm (int): The index of the treatment arm of the treatment group.
183187 control_treatment_arm (int): The index of the treatment arm of the control group.
184- quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10) ].
188+ quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9 ].
185189 alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
186190 n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
187191
188192 Returns:
189193 Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
190- - Expected QTEs
191- - Upper bounds
192- - Lower bounds
194+ - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
195+ - Lower bounds (np.ndarray): Lower confidence interval bounds
196+ - Upper bounds (np.ndarray): Upper confidence interval bounds
197+
198+ Example:
199+ .. code-block:: python
200+
201+ import numpy as np
202+ from dte_adj import SimpleStratifiedDistributionEstimator
203+
204+ # Generate stratified sample data
205+ X = np.random.randn(1000, 5)
206+ strata = np.random.choice([0, 1, 2], size=1000)
207+ D = np.random.binomial(1, 0.5, 1000)
208+ Y = X[:, 0] + 2 * D + 0.5 * strata + np.random.randn(1000)
209+
210+ # Fit stratified estimator
211+ estimator = SimpleStratifiedDistributionEstimator()
212+ estimator.fit(X, D, Y, strata)
213+
214+ # Compute QTE at specific quantiles
215+ quantiles = np.array([0.25, 0.5, 0.75]) # 25th, 50th, 75th percentiles
216+ qte, lower, upper = estimator.predict_qte(
217+ target_treatment_arm=1,
218+ control_treatment_arm=0,
219+ quantiles=quantiles,
220+ n_bootstrap=100
221+ )
222+
223+ print(f"QTE at quantiles {quantiles}: {qte}")
224+ print(f"Median effect (50th percentile): {qte[1]:.3f}")
193225 """
194226 qte = self ._compute_qtes (
195227 target_treatment_arm ,
@@ -198,20 +230,23 @@ def predict_qte(
198230 self .covariates ,
199231 self .treatment_arms ,
200232 self .outcomes ,
233+ self .strata ,
201234 )
202235 n_obs = len (self .outcomes )
203236 indexes = np .arange (n_obs )
204237
205238 qtes = np .zeros ((n_bootstrap , qte .shape [0 ]))
206239 for b in range (n_bootstrap ):
207240 bootstrap_indexes = np .random .choice (indexes , size = n_obs , replace = True )
241+
208242 qtes [b ] = self ._compute_qtes (
209243 target_treatment_arm ,
210244 control_treatment_arm ,
211245 quantiles ,
212246 self .covariates [bootstrap_indexes ],
213247 self .treatment_arms [bootstrap_indexes ],
214248 self .outcomes [bootstrap_indexes ],
249+ self .strata [bootstrap_indexes ],
215250 )
216251
217252 qte_var = qtes .var (axis = 0 )
@@ -333,7 +368,8 @@ def _compute_qtes(
333368 covariates : np .ndarray ,
334369 treatment_arms : np .ndarray ,
335370 outcomes : np .array ,
336- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
371+ strata : np .ndarray ,
372+ ) -> np .ndarray :
337373 """Compute expected QTEs."""
338374 locations = np .sort (outcomes )
339375
@@ -342,13 +378,21 @@ def find_quantile(quantile, arm):
342378 result = - 1
343379 while low <= high :
344380 mid = (low + high ) // 2
381+ # Temporarily store original strata and use the provided strata
382+ original_strata = self .strata
383+ self .strata = strata
384+
345385 val , _ , _ = self ._compute_cumulative_distribution (
346386 arm ,
347387 np .full ((1 ), locations [mid ]),
348388 covariates ,
349389 treatment_arms ,
350390 outcomes ,
351391 )
392+
393+ # Restore original strata
394+ self .strata = original_strata
395+
352396 if val [0 ] <= quantile :
353397 result = locations [mid ]
354398 low = mid + 1
0 commit comments