Skip to content

Commit 5491a23

Browse files
authored
Fix qte (#44)
* fix qte * fix test
1 parent cd6beb8 commit 5491a23

2 files changed

Lines changed: 51 additions & 6 deletions

File tree

dte_adj/__init__.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,20 +176,52 @@ def predict_qte(
176176
n_bootstrap=500,
177177
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
178178
"""
179-
Compute QTE based on the estimator for the distribution function.
179+
Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.
180+
181+
The QTE measures the difference in quantiles between treatment groups, providing insights
182+
into how treatment affects different parts of the outcome distribution. For stratified
183+
estimators, the computation properly accounts for strata.
180184
181185
Args:
182186
target_treatment_arm (int): The index of the treatment arm of the treatment group.
183187
control_treatment_arm (int): The index of the treatment arm of the control group.
184-
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].
188+
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
185189
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
186190
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
187191
188192
Returns:
189193
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
190-
- Expected QTEs
191-
- Upper bounds
192-
- Lower bounds
194+
- Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
195+
- Lower bounds (np.ndarray): Lower confidence interval bounds
196+
- Upper bounds (np.ndarray): Upper confidence interval bounds
197+
198+
Example:
199+
.. code-block:: python
200+
201+
import numpy as np
202+
from dte_adj import SimpleStratifiedDistributionEstimator
203+
204+
# Generate stratified sample data
205+
X = np.random.randn(1000, 5)
206+
strata = np.random.choice([0, 1, 2], size=1000)
207+
D = np.random.binomial(1, 0.5, 1000)
208+
Y = X[:, 0] + 2 * D + 0.5 * strata + np.random.randn(1000)
209+
210+
# Fit stratified estimator
211+
estimator = SimpleStratifiedDistributionEstimator()
212+
estimator.fit(X, D, Y, strata)
213+
214+
# Compute QTE at specific quantiles
215+
quantiles = np.array([0.25, 0.5, 0.75]) # 25th, 50th, 75th percentiles
216+
qte, lower, upper = estimator.predict_qte(
217+
target_treatment_arm=1,
218+
control_treatment_arm=0,
219+
quantiles=quantiles,
220+
n_bootstrap=100
221+
)
222+
223+
print(f"QTE at quantiles {quantiles}: {qte}")
224+
print(f"Median effect (50th percentile): {qte[1]:.3f}")
193225
"""
194226
qte = self._compute_qtes(
195227
target_treatment_arm,
@@ -198,20 +230,23 @@ def predict_qte(
198230
self.covariates,
199231
self.treatment_arms,
200232
self.outcomes,
233+
self.strata,
201234
)
202235
n_obs = len(self.outcomes)
203236
indexes = np.arange(n_obs)
204237

205238
qtes = np.zeros((n_bootstrap, qte.shape[0]))
206239
for b in range(n_bootstrap):
207240
bootstrap_indexes = np.random.choice(indexes, size=n_obs, replace=True)
241+
208242
qtes[b] = self._compute_qtes(
209243
target_treatment_arm,
210244
control_treatment_arm,
211245
quantiles,
212246
self.covariates[bootstrap_indexes],
213247
self.treatment_arms[bootstrap_indexes],
214248
self.outcomes[bootstrap_indexes],
249+
self.strata[bootstrap_indexes],
215250
)
216251

217252
qte_var = qtes.var(axis=0)
@@ -333,7 +368,8 @@ def _compute_qtes(
333368
covariates: np.ndarray,
334369
treatment_arms: np.ndarray,
335370
outcomes: np.array,
336-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
371+
strata: np.ndarray,
372+
) -> np.ndarray:
337373
"""Compute expected QTEs."""
338374
locations = np.sort(outcomes)
339375

@@ -342,13 +378,21 @@ def find_quantile(quantile, arm):
342378
result = -1
343379
while low <= high:
344380
mid = (low + high) // 2
381+
# Temporarily store original strata and use the provided strata
382+
original_strata = self.strata
383+
self.strata = strata
384+
345385
val, _, _ = self._compute_cumulative_distribution(
346386
arm,
347387
np.full((1), locations[mid]),
348388
covariates,
349389
treatment_arms,
350390
outcomes,
351391
)
392+
393+
# Restore original strata
394+
self.strata = original_strata
395+
352396
if val[0] <= quantile:
353397
result = locations[mid]
354398
low = mid + 1

tests/test_distribution_estimator_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def fit(self, covariates, treatment_arms, outcomes):
6060
self.covariates = covariates
6161
self.treatment_arms = treatment_arms
6262
self.outcomes = outcomes
63+
self.strata = np.zeros(len(covariates)) # Mock strata for QTE testing
6364

6465
def _compute_cumulative_distribution(
6566
self,

0 commit comments

Comments
 (0)