Skip to content

Commit 931d926

Browse files
committed
Fix type hints
1 parent a774b20 commit 931d926

6 files changed

Lines changed: 117 additions & 58 deletions

File tree

dte_adj/local.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from typing import Tuple
35
from dte_adj.stratified import (
46
SimpleStratifiedDistributionEstimator,
57
AdjustedStratifiedDistributionEstimator,
68
)
7-
from dte_adj.util import compute_ldte, compute_lpte, _convert_to_ndarray
9+
from dte_adj.util import ArrayLike, compute_ldte, compute_lpte, _convert_to_ndarray
810

911

1012
class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -28,21 +30,21 @@ def __init__(self):
2830

2931
def fit(
3032
self,
31-
covariates: np.ndarray,
32-
treatment_arms: np.ndarray,
33-
treatment_indicator: np.ndarray,
34-
outcomes: np.ndarray,
35-
strata: np.ndarray,
36-
) -> "SimpleLocalDistributionEstimator":
33+
covariates: ArrayLike,
34+
treatment_arms: ArrayLike,
35+
treatment_indicator: ArrayLike,
36+
outcomes: ArrayLike,
37+
strata: ArrayLike,
38+
) -> SimpleLocalDistributionEstimator:
3739
"""
3840
Train the SimpleLocalDistributionEstimator.
3941
4042
Args:
41-
covariates (np.ndarray): Pre-treatment covariates.
42-
treatment_arms (np.ndarray): Treatment assignment variable (Z).
43-
treatment_indicator (np.ndarray): Treatment indicator variable (D).
44-
outcomes (np.ndarray): Scalar-valued observed outcome.
45-
strata (np.ndarray): Stratum indicators.
43+
covariates (ArrayLike): Pre-treatment covariates.
44+
treatment_arms (ArrayLike): Treatment assignment variable (Z).
45+
treatment_indicator (ArrayLike): Treatment indicator variable (D).
46+
outcomes (ArrayLike): Scalar-valued observed outcome.
47+
strata (ArrayLike): Stratum indicators.
4648
4749
Returns:
4850
SimpleLocalDistributionEstimator: The fitted estimator.
@@ -197,21 +199,21 @@ class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator
197199

198200
def fit(
199201
self,
200-
covariates: np.ndarray,
201-
treatment_arms: np.ndarray,
202-
treatment_indicator: np.ndarray,
203-
outcomes: np.ndarray,
204-
strata: np.ndarray,
205-
) -> "AdjustedLocalDistributionEstimator":
202+
covariates: ArrayLike,
203+
treatment_arms: ArrayLike,
204+
treatment_indicator: ArrayLike,
205+
outcomes: ArrayLike,
206+
strata: ArrayLike,
207+
) -> AdjustedLocalDistributionEstimator:
206208
"""
207209
Train the AdjustedLocalDistributionEstimator.
208210
209211
Args:
210-
covariates (np.ndarray): Pre-treatment covariates.
211-
treatment_arms (np.ndarray): Treatment assignment variable (Z).
212-
treatment_indicator (np.ndarray): Treatment indicator variable (D).
213-
outcomes (np.ndarray): Scalar-valued observed outcome.
214-
strata (np.ndarray): Stratum indicators.
212+
covariates (ArrayLike): Pre-treatment covariates.
213+
treatment_arms (ArrayLike): Treatment assignment variable (Z).
214+
treatment_indicator (ArrayLike): Treatment indicator variable (D).
215+
outcomes (ArrayLike): Scalar-valued observed outcome.
216+
strata (ArrayLike): Stratum indicators.
215217
216218
Returns:
217219
AdjustedLocalDistributionEstimator: The fitted estimator.

dte_adj/simple.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from dte_adj.stratified import (
35
SimpleStratifiedDistributionEstimator,
46
AdjustedStratifiedDistributionEstimator,
57
)
6-
from dte_adj.util import _convert_to_ndarray
8+
from dte_adj.util import ArrayLike, _convert_to_ndarray
79

810

911
class SimpleDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -46,15 +48,15 @@ def __init__(self):
4648
super().__init__()
4749

4850
def fit(
49-
self, covariates: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
50-
) -> "SimpleDistributionEstimator":
51+
self, covariates: ArrayLike, treatment_arms: ArrayLike, outcomes: ArrayLike
52+
) -> SimpleDistributionEstimator:
5153
"""
5254
Set parameters.
5355
5456
Args:
55-
covariates (np.ndarray): Pre-treatment covariates.
56-
treatment_arms (np.ndarray): The index of the treatment arm.
57-
outcomes (np.ndarray): Scalar-valued observed outcome.
57+
covariates (ArrayLike): Pre-treatment covariates.
58+
treatment_arms (ArrayLike): The index of the treatment arm.
59+
outcomes (ArrayLike): Scalar-valued observed outcome.
5860
5961
Returns:
6062
SimpleDistributionEstimator: The fitted estimator.
@@ -110,15 +112,15 @@ class AdjustedDistributionEstimator(AdjustedStratifiedDistributionEstimator):
110112
"""
111113

112114
def fit(
113-
self, covariates: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
114-
) -> "AdjustedDistributionEstimator":
115+
self, covariates: ArrayLike, treatment_arms: ArrayLike, outcomes: ArrayLike
116+
) -> AdjustedDistributionEstimator:
115117
"""
116118
Set parameters.
117119
118120
Args:
119-
covariates (np.ndarray): Pre-treatment covariates.
120-
treatment_arms (np.ndarray): The index of the treatment arm.
121-
outcomes (np.ndarray): Scalar-valued observed outcome.
121+
covariates (ArrayLike): Pre-treatment covariates.
122+
treatment_arms (ArrayLike): The index of the treatment arm.
123+
outcomes (ArrayLike): Scalar-valued observed outcome.
122124
123125
Returns:
124126
AdjustedDistributionEstimator: The fitted estimator.

dte_adj/stratified.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from typing import Tuple, Any
35
from copy import deepcopy
46
from dte_adj.base import DistributionEstimatorBase
5-
from dte_adj.util import _convert_to_ndarray
7+
from dte_adj.util import ArrayLike, _convert_to_ndarray
68

79

810
class SimpleStratifiedDistributionEstimator(DistributionEstimatorBase):
911
"""A class is for estimating the empirical distribution function and computing the Distributional parameters for CAR."""
1012

1113
def fit(
1214
self,
13-
covariates: np.ndarray,
14-
treatment_arms: np.ndarray,
15-
outcomes: np.ndarray,
16-
strata: np.ndarray,
17-
) -> "DistributionEstimatorBase":
15+
covariates: ArrayLike,
16+
treatment_arms: ArrayLike,
17+
outcomes: ArrayLike,
18+
strata: ArrayLike,
19+
) -> DistributionEstimatorBase:
1820
"""
1921
Train the DistributionEstimatorBase.
2022
2123
Args:
22-
covariates (np.ndarray): Pre-treatment covariates.
23-
treatment_arms (np.ndarray): The index of the treatment arm.
24-
outcomes (np.ndarray): Scalar-valued observed outcome.
24+
covariates (ArrayLike): Pre-treatment covariates.
25+
treatment_arms (ArrayLike): The index of the treatment arm.
26+
outcomes (ArrayLike): Scalar-valued observed outcome.
27+
strata (ArrayLike): Stratum indicators.
2528
2629
Returns:
2730
DistributionEstimatorBase: The fitted estimator.
@@ -174,18 +177,19 @@ def __init__(self, base_model: Any, folds=3, is_multi_task=False):
174177

175178
def fit(
176179
self,
177-
covariates: np.ndarray,
178-
treatment_arms: np.ndarray,
179-
outcomes: np.ndarray,
180-
strata: np.ndarray,
181-
) -> "DistributionEstimatorBase":
180+
covariates: ArrayLike,
181+
treatment_arms: ArrayLike,
182+
outcomes: ArrayLike,
183+
strata: ArrayLike,
184+
) -> DistributionEstimatorBase:
182185
"""
183186
Train the DistributionEstimatorBase.
184187
185188
Args:
186-
covariates (np.ndarray): Pre-treatment covariates.
187-
treatment_arms (np.ndarray): The index of the treatment arm.
188-
outcomes (np.ndarray): Scalar-valued observed outcome.
189+
covariates (ArrayLike): Pre-treatment covariates.
190+
treatment_arms (ArrayLike): The index of the treatment arm.
191+
outcomes (ArrayLike): Scalar-valued observed outcome.
192+
strata (ArrayLike): Stratum indicators.
189193
190194
Returns:
191195
DistributionEstimatorBase: The fitted estimator.

dte_adj/util.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,36 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from scipy.stats import norm
3-
from typing import Tuple, TYPE_CHECKING
5+
from typing import Tuple, Union, TYPE_CHECKING
46

57
if TYPE_CHECKING:
8+
import pandas as pd
9+
import polars as pl
10+
611
from dte_adj.local import (
712
SimpleStratifiedDistributionEstimator,
813
AdjustedLocalDistributionEstimator,
914
)
1015

11-
12-
def _convert_to_ndarray(data: object) -> np.ndarray:
13-
"""Convert pd.Series or pd.DataFrame to np.ndarray if needed."""
16+
ArrayLike = Union[
17+
np.ndarray,
18+
list,
19+
tuple,
20+
pd.DataFrame,
21+
pd.Series,
22+
pl.DataFrame,
23+
pl.Series,
24+
]
25+
26+
27+
def _convert_to_ndarray(data: ArrayLike) -> np.ndarray:
28+
"""Convert array-like data to np.ndarray if needed."""
29+
if isinstance(data, np.ndarray):
30+
return data
1431
if hasattr(data, "to_numpy"):
1532
return data.to_numpy()
16-
return data
33+
return np.asarray(data)
1734

1835

1936
def compute_confidence_intervals(

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ dev = [
3131
"sphinx>=7.3.7,<8.2.0",
3232
"scikit-learn>=1.5,<1.9",
3333
"pre-commit>=4.0.1,<4.6.0",
34-
"pandas>=2.0"
34+
"pandas>=2.0",
35+
"polars>=1.0"
3536
]
3637

3738
[tool.setuptools.packages.find]
@@ -49,7 +50,8 @@ dev-dependencies = [
4950
"sphinx>=7.3.7,<8.2.0",
5051
"scikit-learn>=1.5,<1.9",
5152
"pre-commit>=4.0.1,<4.6.0",
52-
"pandas>=2.0"
53+
"pandas>=2.0",
54+
"polars>=1.0"
5355
]
5456

5557
[tool.ruff.lint]

uv.lock

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)