Skip to content

Commit d62a65e

Browse files
committed
deploy: 5491a23
1 parent 518f35e commit d62a65e

3 files changed

Lines changed: 261 additions & 37 deletions

File tree

_modules/dte_adj.html

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,52 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
209209
<span class="n">n_bootstrap</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span>
210210
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
211211
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
212-
<span class="sd"> Compute QTE based on the estimator for the distribution function.</span>
212+
<span class="sd"> Compute Quantile Treatment Effects (QTE) based on the estimator for the distribution function.</span>
213+
214+
<span class="sd"> The QTE measures the difference in quantiles between treatment groups, providing insights</span>
215+
<span class="sd"> into how treatment affects different parts of the outcome distribution. For stratified</span>
216+
<span class="sd"> estimators, the computation properly accounts for strata.</span>
213217

214218
<span class="sd"> Args:</span>
215219
<span class="sd"> target_treatment_arm (int): The index of the treatment arm of the treatment group.</span>
216220
<span class="sd"> control_treatment_arm (int): The index of the treatment arm of the control group.</span>
217-
<span class="sd"> quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].</span>
221+
<span class="sd"> quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].</span>
218222
<span class="sd"> alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.</span>
219223
<span class="sd"> n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.</span>
220224

221225
<span class="sd"> Returns:</span>
222226
<span class="sd"> Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:</span>
223-
<span class="sd"> - Expected QTEs</span>
224-
<span class="sd"> - Upper bounds</span>
225-
<span class="sd"> - Lower bounds</span>
227+
<span class="sd"> - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile</span>
228+
<span class="sd"> - Lower bounds (np.ndarray): Lower confidence interval bounds</span>
229+
<span class="sd"> - Upper bounds (np.ndarray): Upper confidence interval bounds</span>
230+
231+
<span class="sd"> Example:</span>
232+
<span class="sd"> .. code-block:: python</span>
233+
234+
<span class="sd"> import numpy as np</span>
235+
<span class="sd"> from dte_adj import SimpleStratifiedDistributionEstimator</span>
236+
<span class="sd"> </span>
237+
<span class="sd"> # Generate stratified sample data</span>
238+
<span class="sd"> X = np.random.randn(1000, 5)</span>
239+
<span class="sd"> strata = np.random.choice([0, 1, 2], size=1000)</span>
240+
<span class="sd"> D = np.random.binomial(1, 0.5, 1000)</span>
241+
<span class="sd"> Y = X[:, 0] + 2 * D + 0.5 * strata + np.random.randn(1000)</span>
242+
<span class="sd"> </span>
243+
<span class="sd"> # Fit stratified estimator</span>
244+
<span class="sd"> estimator = SimpleStratifiedDistributionEstimator()</span>
245+
<span class="sd"> estimator.fit(X, D, Y, strata)</span>
246+
<span class="sd"> </span>
247+
<span class="sd"> # Compute QTE at specific quantiles</span>
248+
<span class="sd"> quantiles = np.array([0.25, 0.5, 0.75]) # 25th, 50th, 75th percentiles</span>
249+
<span class="sd"> qte, lower, upper = estimator.predict_qte(</span>
250+
<span class="sd"> target_treatment_arm=1,</span>
251+
<span class="sd"> control_treatment_arm=0,</span>
252+
<span class="sd"> quantiles=quantiles,</span>
253+
<span class="sd"> n_bootstrap=100</span>
254+
<span class="sd"> )</span>
255+
<span class="sd"> </span>
256+
<span class="sd"> print(f&quot;QTE at quantiles {quantiles}: {qte}&quot;)</span>
257+
<span class="sd"> print(f&quot;Median effect (50th percentile): {qte[1]:.3f}&quot;)</span>
226258
<span class="sd"> &quot;&quot;&quot;</span>
227259
<span class="n">qte</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_qtes</span><span class="p">(</span>
228260
<span class="n">target_treatment_arm</span><span class="p">,</span>
@@ -231,20 +263,23 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
231263
<span class="bp">self</span><span class="o">.</span><span class="n">covariates</span><span class="p">,</span>
232264
<span class="bp">self</span><span class="o">.</span><span class="n">treatment_arms</span><span class="p">,</span>
233265
<span class="bp">self</span><span class="o">.</span><span class="n">outcomes</span><span class="p">,</span>
266+
<span class="bp">self</span><span class="o">.</span><span class="n">strata</span><span class="p">,</span>
234267
<span class="p">)</span>
235268
<span class="n">n_obs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">outcomes</span><span class="p">)</span>
236269
<span class="n">indexes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n_obs</span><span class="p">)</span>
237270

238271
<span class="n">qtes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_bootstrap</span><span class="p">,</span> <span class="n">qte</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
239272
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_bootstrap</span><span class="p">):</span>
240273
<span class="n">bootstrap_indexes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">n_obs</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
274+
241275
<span class="n">qtes</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_qtes</span><span class="p">(</span>
242276
<span class="n">target_treatment_arm</span><span class="p">,</span>
243277
<span class="n">control_treatment_arm</span><span class="p">,</span>
244278
<span class="n">quantiles</span><span class="p">,</span>
245279
<span class="bp">self</span><span class="o">.</span><span class="n">covariates</span><span class="p">[</span><span class="n">bootstrap_indexes</span><span class="p">],</span>
246280
<span class="bp">self</span><span class="o">.</span><span class="n">treatment_arms</span><span class="p">[</span><span class="n">bootstrap_indexes</span><span class="p">],</span>
247281
<span class="bp">self</span><span class="o">.</span><span class="n">outcomes</span><span class="p">[</span><span class="n">bootstrap_indexes</span><span class="p">],</span>
282+
<span class="bp">self</span><span class="o">.</span><span class="n">strata</span><span class="p">[</span><span class="n">bootstrap_indexes</span><span class="p">],</span>
248283
<span class="p">)</span>
249284

250285
<span class="n">qte_var</span> <span class="o">=</span> <span class="n">qtes</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
@@ -366,7 +401,8 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
366401
<span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
367402
<span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
368403
<span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span>
369-
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
404+
<span class="n">strata</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
405+
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
370406
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Compute expected QTEs.&quot;&quot;&quot;</span>
371407
<span class="n">locations</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">outcomes</span><span class="p">)</span>
372408

@@ -375,13 +411,21 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
375411
<span class="n">result</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
376412
<span class="k">while</span> <span class="n">low</span> <span class="o">&lt;=</span> <span class="n">high</span><span class="p">:</span>
377413
<span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">low</span> <span class="o">+</span> <span class="n">high</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span>
414+
<span class="c1"># Temporarily store original strata and use the provided strata</span>
415+
<span class="n">original_strata</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">strata</span>
416+
<span class="bp">self</span><span class="o">.</span><span class="n">strata</span> <span class="o">=</span> <span class="n">strata</span>
417+
378418
<span class="n">val</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_cumulative_distribution</span><span class="p">(</span>
379419
<span class="n">arm</span><span class="p">,</span>
380420
<span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="mi">1</span><span class="p">),</span> <span class="n">locations</span><span class="p">[</span><span class="n">mid</span><span class="p">]),</span>
381421
<span class="n">covariates</span><span class="p">,</span>
382422
<span class="n">treatment_arms</span><span class="p">,</span>
383423
<span class="n">outcomes</span><span class="p">,</span>
384424
<span class="p">)</span>
425+
426+
<span class="c1"># Restore original strata</span>
427+
<span class="bp">self</span><span class="o">.</span><span class="n">strata</span> <span class="o">=</span> <span class="n">original_strata</span>
428+
385429
<span class="k">if</span> <span class="n">val</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">quantile</span><span class="p">:</span>
386430
<span class="n">result</span> <span class="o">=</span> <span class="n">locations</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span>
387431
<span class="n">low</span> <span class="o">=</span> <span class="n">mid</span> <span class="o">+</span> <span class="mi">1</span>

0 commit comments

Comments
 (0)