Skip to content

Commit b4c441c

Browse files
committed
Fix docstrings of fit() methods
1 parent 01beb81 commit b4c441c

3 files changed

Lines changed: 24 additions & 24 deletions

File tree

dte_adj/local.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def fit(
4040
Train the SimpleLocalDistributionEstimator.
4141
4242
Args:
43-
covariates (ArrayLike): Pre-treatment covariates.
44-
treatment_arms (ArrayLike): Treatment assignment variable (Z).
45-
treatment_indicator (ArrayLike): Treatment indicator variable (D).
46-
outcomes (ArrayLike): Scalar-valued observed outcome.
47-
strata (ArrayLike): Stratum indicators.
43+
covariates: Pre-treatment covariates.
44+
treatment_arms: Treatment assignment variable (Z).
45+
treatment_indicator: Treatment indicator variable (D).
46+
outcomes: Scalar-valued observed outcome.
47+
strata: Stratum indicators.
4848
4949
Returns:
5050
SimpleLocalDistributionEstimator: The fitted estimator.
@@ -209,11 +209,11 @@ def fit(
209209
Train the AdjustedLocalDistributionEstimator.
210210
211211
Args:
212-
covariates (ArrayLike): Pre-treatment covariates.
213-
treatment_arms (ArrayLike): Treatment assignment variable (Z).
214-
treatment_indicator (ArrayLike): Treatment indicator variable (D).
215-
outcomes (ArrayLike): Scalar-valued observed outcome.
216-
strata (ArrayLike): Stratum indicators.
212+
covariates: Pre-treatment covariates.
213+
treatment_arms: Treatment assignment variable (Z).
214+
treatment_indicator: Treatment indicator variable (D).
215+
outcomes: Scalar-valued observed outcome.
216+
strata: Stratum indicators.
217217
218218
Returns:
219219
AdjustedLocalDistributionEstimator: The fitted estimator.

dte_adj/simple.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def fit(
5454
Set parameters.
5555
5656
Args:
57-
covariates (ArrayLike): Pre-treatment covariates.
58-
treatment_arms (ArrayLike): The index of the treatment arm.
59-
outcomes (ArrayLike): Scalar-valued observed outcome.
57+
covariates: Pre-treatment covariates.
58+
treatment_arms: The index of the treatment arm.
59+
outcomes: Scalar-valued observed outcome.
6060
6161
Returns:
6262
SimpleDistributionEstimator: The fitted estimator.
@@ -118,9 +118,9 @@ def fit(
118118
Set parameters.
119119
120120
Args:
121-
covariates (ArrayLike): Pre-treatment covariates.
122-
treatment_arms (ArrayLike): The index of the treatment arm.
123-
outcomes (ArrayLike): Scalar-valued observed outcome.
121+
covariates: Pre-treatment covariates.
122+
treatment_arms: The index of the treatment arm.
123+
outcomes: Scalar-valued observed outcome.
124124
125125
Returns:
126126
AdjustedDistributionEstimator: The fitted estimator.

dte_adj/stratified.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ def fit(
2121
Train the DistributionEstimatorBase.
2222
2323
Args:
24-
covariates (ArrayLike): Pre-treatment covariates.
25-
treatment_arms (ArrayLike): The index of the treatment arm.
26-
outcomes (ArrayLike): Scalar-valued observed outcome.
27-
strata (ArrayLike): Stratum indicators.
24+
covariates: Pre-treatment covariates.
25+
treatment_arms: The index of the treatment arm.
26+
outcomes: Scalar-valued observed outcome.
27+
strata: Stratum indicators.
2828
2929
Returns:
3030
DistributionEstimatorBase: The fitted estimator.
@@ -186,10 +186,10 @@ def fit(
186186
Train the DistributionEstimatorBase.
187187
188188
Args:
189-
covariates (ArrayLike): Pre-treatment covariates.
190-
treatment_arms (ArrayLike): The index of the treatment arm.
191-
outcomes (ArrayLike): Scalar-valued observed outcome.
192-
strata (ArrayLike): Stratum indicators.
189+
covariates: Pre-treatment covariates.
190+
treatment_arms: The index of the treatment arm.
191+
outcomes: Scalar-valued observed outcome.
192+
strata: Stratum indicators.
193193
194194
Returns:
195195
DistributionEstimatorBase: The fitted estimator.

0 commit comments

Comments
 (0)