Skip to content

Commit 958f2b7

Browse files
committed
deploy: 29de8a0
1 parent 92d2449 commit 958f2b7

7 files changed

Lines changed: 113 additions & 81 deletions

File tree

_modules/dte_adj/local.html

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030
<div class="body" role="main">
3131

3232
<h1>Source code for dte_adj.local</h1><div class="highlight"><pre>
33-
<span></span><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
33+
<span></span><span class="kn">from</span><span class="w"> </span><span class="nn">__future__</span><span class="w"> </span><span class="kn">import</span> <span class="n">annotations</span>
34+
35+
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
3436
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tuple</span>
3537
<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.stratified</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
3638
<span class="n">SimpleStratifiedDistributionEstimator</span><span class="p">,</span>
3739
<span class="n">AdjustedStratifiedDistributionEstimator</span><span class="p">,</span>
3840
<span class="p">)</span>
39-
<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">compute_ldte</span><span class="p">,</span> <span class="n">compute_lpte</span>
41+
<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">compute_ldte</span><span class="p">,</span> <span class="n">compute_lpte</span><span class="p">,</span> <span class="n">_convert_to_ndarray</span>
4042

4143

4244
<div class="viewcode-block" id="SimpleLocalDistributionEstimator">
@@ -64,25 +66,26 @@ <h1>Source code for dte_adj.local</h1><div class="highlight"><pre>
6466
<a class="viewcode-back" href="../../api/local.html#dte_adj.SimpleLocalDistributionEstimator.fit">[docs]</a>
6567
<span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span>
6668
<span class="bp">self</span><span class="p">,</span>
67-
<span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
68-
<span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
69-
<span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
70-
<span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
71-
<span class="n">strata</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
72-
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;SimpleLocalDistributionEstimator&quot;</span><span class="p">:</span>
69+
<span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
70+
<span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
71+
<span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
72+
<span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
73+
<span class="n">strata</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
74+
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SimpleLocalDistributionEstimator</span><span class="p">:</span>
7375
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
7476
<span class="sd"> Train the SimpleLocalDistributionEstimator.</span>
7577

7678
<span class="sd"> Args:</span>
77-
<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span>
78-
<span class="sd"> treatment_arms (np.ndarray): Treatment assignment variable (Z).</span>
79-
<span class="sd"> treatment_indicator (np.ndarray): Treatment indicator variable (D).</span>
80-
<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span>
81-
<span class="sd"> strata (np.ndarray): Stratum indicators.</span>
79+
<span class="sd"> covariates: Pre-treatment covariates.</span>
80+
<span class="sd"> treatment_arms: Treatment assignment variable (Z).</span>
81+
<span class="sd"> treatment_indicator: Treatment indicator variable (D).</span>
82+
<span class="sd"> outcomes: Scalar-valued observed outcome.</span>
83+
<span class="sd"> strata: Stratum indicators.</span>
8284

8385
<span class="sd"> Returns:</span>
8486
<span class="sd"> SimpleLocalDistributionEstimator: The fitted estimator.</span>
8587
<span class="sd"> &quot;&quot;&quot;</span>
88+
<span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_indicator</span><span class="p">)</span>
8689
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">covariates</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">,</span> <span class="n">strata</span><span class="p">)</span>
8790
<span class="bp">self</span><span class="o">.</span><span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">treatment_indicator</span>
8891

@@ -244,25 +247,26 @@ <h1>Source code for dte_adj.local</h1><div class="highlight"><pre>
244247
<a class="viewcode-back" href="../../api/local.html#dte_adj.AdjustedLocalDistributionEstimator.fit">[docs]</a>
245248
<span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span>
246249
<span class="bp">self</span><span class="p">,</span>
247-
<span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
248-
<span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
249-
<span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
250-
<span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
251-
<span class="n">strata</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
252-
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;AdjustedLocalDistributionEstimator&quot;</span><span class="p">:</span>
250+
<span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
251+
<span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
252+
<span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
253+
<span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
254+
<span class="n">strata</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span>
255+
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AdjustedLocalDistributionEstimator</span><span class="p">:</span>
253256
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
254257
<span class="sd"> Train the AdjustedLocalDistributionEstimator.</span>
255258

256259
<span class="sd"> Args:</span>
257-
<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span>
258-
<span class="sd"> treatment_arms (np.ndarray): Treatment assignment variable (Z).</span>
259-
<span class="sd"> treatment_indicator (np.ndarray): Treatment indicator variable (D).</span>
260-
<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span>
261-
<span class="sd"> strata (np.ndarray): Stratum indicators.</span>
260+
<span class="sd"> covariates: Pre-treatment covariates.</span>
261+
<span class="sd"> treatment_arms: Treatment assignment variable (Z).</span>
262+
<span class="sd"> treatment_indicator: Treatment indicator variable (D).</span>
263+
<span class="sd"> outcomes: Scalar-valued observed outcome.</span>
264+
<span class="sd"> strata: Stratum indicators.</span>
262265

263266
<span class="sd"> Returns:</span>
264267
<span class="sd"> AdjustedLocalDistributionEstimator: The fitted estimator.</span>
265268
<span class="sd"> &quot;&quot;&quot;</span>
269+
<span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_indicator</span><span class="p">)</span>
266270
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">covariates</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">,</span> <span class="n">strata</span><span class="p">)</span>
267271
<span class="bp">self</span><span class="o">.</span><span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">treatment_indicator</span>
268272

_modules/dte_adj/simple.html

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@
3030
<div class="body" role="main">
3131

3232
<h1>Source code for dte_adj.simple</h1><div class="highlight"><pre>
33-
<span></span><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
33+
<span></span><span class="kn">from</span><span class="w"> </span><span class="nn">__future__</span><span class="w"> </span><span class="kn">import</span> <span class="n">annotations</span>
34+
35+
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
3436
<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.stratified</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
3537
<span class="n">SimpleStratifiedDistributionEstimator</span><span class="p">,</span>
3638
<span class="n">AdjustedStratifiedDistributionEstimator</span><span class="p">,</span>
3739
<span class="p">)</span>
40+
<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">_convert_to_ndarray</span>
3841

3942

4043
<div class="viewcode-block" id="SimpleDistributionEstimator">
@@ -81,19 +84,23 @@ <h1>Source code for dte_adj.simple</h1><div class="highlight"><pre>
8184
<div class="viewcode-block" id="SimpleDistributionEstimator.fit">
8285
<a class="viewcode-back" href="../../api/simple.html#dte_adj.SimpleDistributionEstimator.fit">[docs]</a>
8386
<span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span>
84-
<span class="bp">self</span><span class="p">,</span> <span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span>
85-
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;SimpleDistributionEstimator&quot;</span><span class="p">:</span>
87+
<span class="bp">self</span><span class="p">,</span> <span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span>
88+
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SimpleDistributionEstimator</span><span class="p">:</span>
8689
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
8790
<span class="sd"> Set parameters.</span>
8891

8992
<span class="sd"> Args:</span>
90-
<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span>
91-
<span class="sd"> treatment_arms (np.ndarray): The index of the treatment arm.</span>
92-
<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span>
93+
<span class="sd"> covariates: Pre-treatment covariates.</span>
94+
<span class="sd"> treatment_arms: The index of the treatment arm.</span>
95+
<span class="sd"> outcomes: Scalar-valued observed outcome.</span>
9396

9497
<span class="sd"> Returns:</span>
9598
<span class="sd"> SimpleDistributionEstimator: The fitted estimator.</span>
9699
<span class="sd"> &quot;&quot;&quot;</span>
100+
<span class="n">covariates</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">covariates</span><span class="p">)</span>
101+
<span class="n">treatment_arms</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_arms</span><span class="p">)</span>
102+
<span class="n">outcomes</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">outcomes</span><span class="p">)</span>
103+
97104
<span class="k">if</span> <span class="n">covariates</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">treatment_arms</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
98105
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The shape of covariates and treatment_arm should be same&quot;</span><span class="p">)</span>
99106

@@ -147,19 +154,23 @@ <h1>Source code for dte_adj.simple</h1><div class="highlight"><pre>
147154
<div class="viewcode-block" id="AdjustedDistributionEstimator.fit">
148155
<a class="viewcode-back" href="../../api/simple.html#dte_adj.AdjustedDistributionEstimator.fit">[docs]</a>
149156
<span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span>
150-
<span class="bp">self</span><span class="p">,</span> <span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span>
151-
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;AdjustedDistributionEstimator&quot;</span><span class="p">:</span>
157+
<span class="bp">self</span><span class="p">,</span> <span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span>
158+
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AdjustedDistributionEstimator</span><span class="p">:</span>
152159
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
153160
<span class="sd"> Set parameters.</span>
154161

155162
<span class="sd"> Args:</span>
156-
<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span>
157-
<span class="sd"> treatment_arms (np.ndarray): The index of the treatment arm.</span>
158-
<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span>
163+
<span class="sd"> covariates: Pre-treatment covariates.</span>
164+
<span class="sd"> treatment_arms: The index of the treatment arm.</span>
165+
<span class="sd"> outcomes: Scalar-valued observed outcome.</span>
159166

160167
<span class="sd"> Returns:</span>
161168
<span class="sd"> AdjustedDistributionEstimator: The fitted estimator.</span>
162169
<span class="sd"> &quot;&quot;&quot;</span>
170+
<span class="n">covariates</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">covariates</span><span class="p">)</span>
171+
<span class="n">treatment_arms</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_arms</span><span class="p">)</span>
172+
<span class="n">outcomes</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">outcomes</span><span class="p">)</span>
173+
163174
<span class="k">if</span> <span class="n">covariates</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">treatment_arms</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
164175
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The shape of covariates and treatment_arm should be same&quot;</span><span class="p">)</span>
165176

0 commit comments

Comments
 (0)