|
30 | 30 | <div class="body" role="main"> |
31 | 31 |
|
32 | 32 | <h1>Source code for dte_adj.local</h1><div class="highlight"><pre> |
33 | | -<span></span><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> |
| 33 | +<span></span><span class="kn">from</span><span class="w"> </span><span class="nn">__future__</span><span class="w"> </span><span class="kn">import</span> <span class="n">annotations</span> |
| 34 | + |
| 35 | +<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> |
34 | 36 | <span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tuple</span> |
35 | 37 | <span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.stratified</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> |
36 | 38 | <span class="n">SimpleStratifiedDistributionEstimator</span><span class="p">,</span> |
37 | 39 | <span class="n">AdjustedStratifiedDistributionEstimator</span><span class="p">,</span> |
38 | 40 | <span class="p">)</span> |
39 | | -<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">compute_ldte</span><span class="p">,</span> <span class="n">compute_lpte</span> |
| 41 | +<span class="kn">from</span><span class="w"> </span><span class="nn">dte_adj.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">compute_ldte</span><span class="p">,</span> <span class="n">compute_lpte</span><span class="p">,</span> <span class="n">_convert_to_ndarray</span> |
40 | 42 |
|
41 | 43 |
|
42 | 44 | <div class="viewcode-block" id="SimpleLocalDistributionEstimator"> |
@@ -64,25 +66,26 @@ <h1>Source code for dte_adj.local</h1><div class="highlight"><pre> |
64 | 66 | <a class="viewcode-back" href="../../api/local.html#dte_adj.SimpleLocalDistributionEstimator.fit">[docs]</a> |
65 | 67 | <span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span> |
66 | 68 | <span class="bp">self</span><span class="p">,</span> |
67 | | - <span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
68 | | - <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
69 | | - <span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
70 | | - <span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
71 | | - <span class="n">strata</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
72 | | - <span class="p">)</span> <span class="o">-></span> <span class="s2">"SimpleLocalDistributionEstimator"</span><span class="p">:</span> |
| 69 | + <span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 70 | + <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 71 | + <span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 72 | + <span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 73 | + <span class="n">strata</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 74 | + <span class="p">)</span> <span class="o">-></span> <span class="n">SimpleLocalDistributionEstimator</span><span class="p">:</span> |
73 | 75 | <span class="w"> </span><span class="sd">"""</span> |
74 | 76 | <span class="sd"> Train the SimpleLocalDistributionEstimator.</span> |
75 | 77 |
|
76 | 78 | <span class="sd"> Args:</span> |
77 | | -<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span> |
78 | | -<span class="sd"> treatment_arms (np.ndarray): Treatment assignment variable (Z).</span> |
79 | | -<span class="sd"> treatment_indicator (np.ndarray): Treatment indicator variable (D).</span> |
80 | | -<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span> |
81 | | -<span class="sd"> strata (np.ndarray): Stratum indicators.</span> |
| 79 | +<span class="sd"> covariates: Pre-treatment covariates.</span> |
| 80 | +<span class="sd"> treatment_arms: Treatment assignment variable (Z).</span> |
| 81 | +<span class="sd"> treatment_indicator: Treatment indicator variable (D).</span> |
| 82 | +<span class="sd"> outcomes: Scalar-valued observed outcome.</span> |
| 83 | +<span class="sd"> strata: Stratum indicators.</span> |
82 | 84 |
|
83 | 85 | <span class="sd"> Returns:</span> |
84 | 86 | <span class="sd"> SimpleLocalDistributionEstimator: The fitted estimator.</span> |
85 | 87 | <span class="sd"> """</span> |
| 88 | + <span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_indicator</span><span class="p">)</span> |
86 | 89 | <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">covariates</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">,</span> <span class="n">strata</span><span class="p">)</span> |
87 | 90 | <span class="bp">self</span><span class="o">.</span><span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">treatment_indicator</span> |
88 | 91 |
|
@@ -244,25 +247,26 @@ <h1>Source code for dte_adj.local</h1><div class="highlight"><pre> |
244 | 247 | <a class="viewcode-back" href="../../api/local.html#dte_adj.AdjustedLocalDistributionEstimator.fit">[docs]</a> |
245 | 248 | <span class="k">def</span><span class="w"> </span><span class="nf">fit</span><span class="p">(</span> |
246 | 249 | <span class="bp">self</span><span class="p">,</span> |
247 | | - <span class="n">covariates</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
248 | | - <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
249 | | - <span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
250 | | - <span class="n">outcomes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
251 | | - <span class="n">strata</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> |
252 | | - <span class="p">)</span> <span class="o">-></span> <span class="s2">"AdjustedLocalDistributionEstimator"</span><span class="p">:</span> |
| 250 | + <span class="n">covariates</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 251 | + <span class="n">treatment_arms</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 252 | + <span class="n">treatment_indicator</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 253 | + <span class="n">outcomes</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 254 | + <span class="n">strata</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> |
| 255 | + <span class="p">)</span> <span class="o">-></span> <span class="n">AdjustedLocalDistributionEstimator</span><span class="p">:</span> |
253 | 256 | <span class="w"> </span><span class="sd">"""</span> |
254 | 257 | <span class="sd"> Train the AdjustedLocalDistributionEstimator.</span> |
255 | 258 |
|
256 | 259 | <span class="sd"> Args:</span> |
257 | | -<span class="sd"> covariates (np.ndarray): Pre-treatment covariates.</span> |
258 | | -<span class="sd"> treatment_arms (np.ndarray): Treatment assignment variable (Z).</span> |
259 | | -<span class="sd"> treatment_indicator (np.ndarray): Treatment indicator variable (D).</span> |
260 | | -<span class="sd"> outcomes (np.ndarray): Scalar-valued observed outcome.</span> |
261 | | -<span class="sd"> strata (np.ndarray): Stratum indicators.</span> |
| 260 | +<span class="sd"> covariates: Pre-treatment covariates.</span> |
| 261 | +<span class="sd"> treatment_arms: Treatment assignment variable (Z).</span> |
| 262 | +<span class="sd"> treatment_indicator: Treatment indicator variable (D).</span> |
| 263 | +<span class="sd"> outcomes: Scalar-valued observed outcome.</span> |
| 264 | +<span class="sd"> strata: Stratum indicators.</span> |
262 | 265 |
|
263 | 266 | <span class="sd"> Returns:</span> |
264 | 267 | <span class="sd"> AdjustedLocalDistributionEstimator: The fitted estimator.</span> |
265 | 268 | <span class="sd"> """</span> |
| 269 | + <span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">_convert_to_ndarray</span><span class="p">(</span><span class="n">treatment_indicator</span><span class="p">)</span> |
266 | 270 | <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">covariates</span><span class="p">,</span> <span class="n">treatment_arms</span><span class="p">,</span> <span class="n">outcomes</span><span class="p">,</span> <span class="n">strata</span><span class="p">)</span> |
267 | 271 | <span class="bp">self</span><span class="o">.</span><span class="n">treatment_indicator</span> <span class="o">=</span> <span class="n">treatment_indicator</span> |
268 | 272 |
|
|
0 commit comments