1+ from __future__ import annotations
2+
13import numpy as np
24from typing import Tuple
35from dte_adj .stratified import (
46 SimpleStratifiedDistributionEstimator ,
57 AdjustedStratifiedDistributionEstimator ,
68)
7- from dte_adj .util import compute_ldte , compute_lpte , _convert_to_ndarray
9+ from dte_adj .util import ArrayLike , compute_ldte , compute_lpte , _convert_to_ndarray
810
911
1012class SimpleLocalDistributionEstimator (SimpleStratifiedDistributionEstimator ):
@@ -28,21 +30,21 @@ def __init__(self):
2830
2931 def fit (
3032 self ,
31- covariates : np . ndarray ,
32- treatment_arms : np . ndarray ,
33- treatment_indicator : np . ndarray ,
34- outcomes : np . ndarray ,
35- strata : np . ndarray ,
36- ) -> " SimpleLocalDistributionEstimator" :
33+ covariates : ArrayLike ,
34+ treatment_arms : ArrayLike ,
35+ treatment_indicator : ArrayLike ,
36+ outcomes : ArrayLike ,
37+ strata : ArrayLike ,
38+ ) -> SimpleLocalDistributionEstimator :
3739 """
3840 Train the SimpleLocalDistributionEstimator.
3941
4042 Args:
41- covariates (np.ndarray ): Pre-treatment covariates.
42- treatment_arms (np.ndarray ): Treatment assignment variable (Z).
43- treatment_indicator (np.ndarray ): Treatment indicator variable (D).
44- outcomes (np.ndarray ): Scalar-valued observed outcome.
45- strata (np.ndarray ): Stratum indicators.
43+ covariates (ArrayLike ): Pre-treatment covariates.
44+ treatment_arms (ArrayLike ): Treatment assignment variable (Z).
45+ treatment_indicator (ArrayLike ): Treatment indicator variable (D).
46+ outcomes (ArrayLike ): Scalar-valued observed outcome.
47+ strata (ArrayLike ): Stratum indicators.
4648
4749 Returns:
4850 SimpleLocalDistributionEstimator: The fitted estimator.
@@ -197,21 +199,21 @@ class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator
197199
198200 def fit (
199201 self ,
200- covariates : np . ndarray ,
201- treatment_arms : np . ndarray ,
202- treatment_indicator : np . ndarray ,
203- outcomes : np . ndarray ,
204- strata : np . ndarray ,
205- ) -> " AdjustedLocalDistributionEstimator" :
202+ covariates : ArrayLike ,
203+ treatment_arms : ArrayLike ,
204+ treatment_indicator : ArrayLike ,
205+ outcomes : ArrayLike ,
206+ strata : ArrayLike ,
207+ ) -> AdjustedLocalDistributionEstimator :
206208 """
207209 Train the AdjustedLocalDistributionEstimator.
208210
209211 Args:
210- covariates (np.ndarray ): Pre-treatment covariates.
211- treatment_arms (np.ndarray ): Treatment assignment variable (Z).
212- treatment_indicator (np.ndarray ): Treatment indicator variable (D).
213- outcomes (np.ndarray ): Scalar-valued observed outcome.
214- strata (np.ndarray ): Stratum indicators.
212+ covariates (ArrayLike ): Pre-treatment covariates.
213+ treatment_arms (ArrayLike ): Treatment assignment variable (Z).
214+ treatment_indicator (ArrayLike ): Treatment indicator variable (D).
215+ outcomes (ArrayLike ): Scalar-valued observed outcome.
216+ strata (ArrayLike ): Stratum indicators.
215217
216218 Returns:
217219 AdjustedLocalDistributionEstimator: The fitted estimator.
0 commit comments