Skip to content

Commit 163cdd9

Browse files
committed
Allow fit methods to accept pd.Series and pd.DataFrame (#62)
1 parent 32e0873 commit 163cdd9

5 files changed

Lines changed: 191 additions & 1 deletion

File tree

dte_adj/local.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
SimpleStratifiedDistributionEstimator,
55
AdjustedStratifiedDistributionEstimator,
66
)
7-
from dte_adj.util import compute_ldte, compute_lpte
7+
from dte_adj.util import compute_ldte, compute_lpte, _convert_to_ndarray
88

99

1010
class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -47,6 +47,7 @@ def fit(
4747
Returns:
4848
SimpleLocalDistributionEstimator: The fitted estimator.
4949
"""
50+
treatment_indicator = _convert_to_ndarray(treatment_indicator)
5051
super().fit(covariates, treatment_arms, outcomes, strata)
5152
self.treatment_indicator = treatment_indicator
5253

@@ -215,6 +216,7 @@ def fit(
215216
Returns:
216217
AdjustedLocalDistributionEstimator: The fitted estimator.
217218
"""
219+
treatment_indicator = _convert_to_ndarray(treatment_indicator)
218220
super().fit(covariates, treatment_arms, outcomes, strata)
219221
self.treatment_indicator = treatment_indicator
220222

dte_adj/simple.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
SimpleStratifiedDistributionEstimator,
44
AdjustedStratifiedDistributionEstimator,
55
)
6+
from dte_adj.util import _convert_to_ndarray
67

78

89
class SimpleDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -58,6 +59,10 @@ def fit(
5859
Returns:
5960
SimpleDistributionEstimator: The fitted estimator.
6061
"""
62+
covariates = _convert_to_ndarray(covariates)
63+
treatment_arms = _convert_to_ndarray(treatment_arms)
64+
outcomes = _convert_to_ndarray(outcomes)
65+
6166
if covariates.shape[0] != treatment_arms.shape[0]:
6267
raise ValueError("The shape of covariates and treatment_arm should be same")
6368

@@ -118,6 +123,10 @@ def fit(
118123
Returns:
119124
AdjustedDistributionEstimator: The fitted estimator.
120125
"""
126+
covariates = _convert_to_ndarray(covariates)
127+
treatment_arms = _convert_to_ndarray(treatment_arms)
128+
outcomes = _convert_to_ndarray(outcomes)
129+
121130
if covariates.shape[0] != treatment_arms.shape[0]:
122131
raise ValueError("The shape of covariates and treatment_arm should be same")
123132

dte_adj/stratified.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Tuple, Any
33
from copy import deepcopy
44
from dte_adj.base import DistributionEstimatorBase
5+
from dte_adj.util import _convert_to_ndarray
56

67

78
class SimpleStratifiedDistributionEstimator(DistributionEstimatorBase):
@@ -25,6 +26,11 @@ def fit(
2526
Returns:
2627
DistributionEstimatorBase: The fitted estimator.
2728
"""
29+
covariates = _convert_to_ndarray(covariates)
30+
treatment_arms = _convert_to_ndarray(treatment_arms)
31+
outcomes = _convert_to_ndarray(outcomes)
32+
strata = _convert_to_ndarray(strata)
33+
2834
if covariates.shape[0] != treatment_arms.shape[0]:
2935
raise ValueError("The shape of covariates and treatment_arm should be same")
3036

@@ -184,6 +190,11 @@ def fit(
184190
Returns:
185191
DistributionEstimatorBase: The fitted estimator.
186192
"""
193+
covariates = _convert_to_ndarray(covariates)
194+
treatment_arms = _convert_to_ndarray(treatment_arms)
195+
outcomes = _convert_to_ndarray(outcomes)
196+
strata = _convert_to_ndarray(strata)
197+
187198
if covariates.shape[0] != treatment_arms.shape[0]:
188199
raise ValueError("The shape of covariates and treatment_arm should be same")
189200

dte_adj/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
)
1010

1111

12+
def _convert_to_ndarray(data: object) -> np.ndarray:
13+
"""Convert pd.Series or pd.DataFrame to np.ndarray if needed."""
14+
if hasattr(data, "to_numpy"):
15+
return data.to_numpy()
16+
return data
17+
18+
1219
def compute_confidence_intervals(
1320
vec_y: np.ndarray,
1421
vec_d: np.ndarray,

tests/test_pandas_input.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import unittest
2+
import numpy as np
3+
import pandas as pd
4+
from unittest.mock import MagicMock
5+
from sklearn.linear_model import LogisticRegression
6+
from dte_adj import (
7+
SimpleDistributionEstimator,
8+
AdjustedDistributionEstimator,
9+
SimpleStratifiedDistributionEstimator,
10+
AdjustedStratifiedDistributionEstimator,
11+
SimpleLocalDistributionEstimator,
12+
AdjustedLocalDistributionEstimator,
13+
)
14+
15+
16+
class TestPandasInputSimple(unittest.TestCase):
17+
"""Test that Simple/Adjusted DistributionEstimator accept pandas inputs."""
18+
19+
def setUp(self):
20+
np.random.seed(42)
21+
n = 20
22+
self.covariates_df = pd.DataFrame(np.zeros((n, 5)), columns=[f"x{i}" for i in range(5)])
23+
self.treatment_arms_series = pd.Series(np.hstack([np.zeros(10), np.ones(10)]))
24+
self.outcomes_series = pd.Series(np.arange(n, dtype=float))
25+
26+
def test_simple_estimator_with_dataframe_and_series(self):
27+
estimator = SimpleDistributionEstimator()
28+
result = estimator.fit(
29+
self.covariates_df, self.treatment_arms_series, self.outcomes_series
30+
)
31+
32+
self.assertIsInstance(result.covariates, np.ndarray)
33+
self.assertIsInstance(result.treatment_arms, np.ndarray)
34+
self.assertIsInstance(result.outcomes, np.ndarray)
35+
36+
def test_simple_estimator_predict_after_pandas_fit(self):
37+
estimator = SimpleDistributionEstimator()
38+
estimator.fit(self.covariates_df, self.treatment_arms_series, self.outcomes_series)
39+
40+
output = estimator.predict(0, np.array([3, 6]))
41+
expected = np.array([0.4, 0.7])
42+
np.testing.assert_array_almost_equal(output, expected, decimal=2)
43+
44+
def test_adjusted_estimator_with_dataframe_and_series(self):
45+
base_model = MagicMock()
46+
base_model.predict_proba.side_effect = lambda x, y: x
47+
estimator = AdjustedDistributionEstimator(base_model, folds=2)
48+
result = estimator.fit(
49+
self.covariates_df, self.treatment_arms_series, self.outcomes_series
50+
)
51+
52+
self.assertIsInstance(result.covariates, np.ndarray)
53+
self.assertIsInstance(result.treatment_arms, np.ndarray)
54+
self.assertIsInstance(result.outcomes, np.ndarray)
55+
56+
57+
class TestPandasInputStratified(unittest.TestCase):
58+
"""Test that Stratified estimators accept pandas inputs."""
59+
60+
def setUp(self):
61+
np.random.seed(42)
62+
n = 100
63+
self.covariates_df = pd.DataFrame(
64+
np.random.randn(n, 5), columns=[f"x{i}" for i in range(5)]
65+
)
66+
self.treatment_arms_series = pd.Series(np.random.choice([0, 1], size=n))
67+
self.outcomes_series = pd.Series(np.random.randn(n))
68+
self.strata_series = pd.Series(np.random.choice([0, 1, 2], size=n))
69+
70+
def test_simple_stratified_with_pandas(self):
71+
estimator = SimpleStratifiedDistributionEstimator()
72+
result = estimator.fit(
73+
self.covariates_df,
74+
self.treatment_arms_series,
75+
self.outcomes_series,
76+
self.strata_series,
77+
)
78+
79+
self.assertIsInstance(result.covariates, np.ndarray)
80+
self.assertIsInstance(result.treatment_arms, np.ndarray)
81+
self.assertIsInstance(result.outcomes, np.ndarray)
82+
self.assertIsInstance(result.strata, np.ndarray)
83+
84+
def test_adjusted_stratified_with_pandas(self):
85+
base_model = LogisticRegression(random_state=42)
86+
estimator = AdjustedStratifiedDistributionEstimator(base_model, folds=2)
87+
result = estimator.fit(
88+
self.covariates_df,
89+
self.treatment_arms_series,
90+
self.outcomes_series,
91+
self.strata_series,
92+
)
93+
94+
self.assertIsInstance(result.covariates, np.ndarray)
95+
self.assertIsInstance(result.treatment_arms, np.ndarray)
96+
self.assertIsInstance(result.outcomes, np.ndarray)
97+
self.assertIsInstance(result.strata, np.ndarray)
98+
99+
100+
class TestPandasInputLocal(unittest.TestCase):
101+
"""Test that Local estimators accept pandas inputs."""
102+
103+
def setUp(self):
104+
np.random.seed(42)
105+
n = 100
106+
self.covariates_df = pd.DataFrame(
107+
np.random.randn(n, 3), columns=[f"x{i}" for i in range(3)]
108+
)
109+
self.treatment_arms_series = pd.Series(np.random.choice([0, 1], size=n))
110+
self.treatment_indicator_series = pd.Series(np.random.choice([0, 1], size=n))
111+
self.outcomes_series = pd.Series(np.random.randn(n))
112+
self.strata_series = pd.Series(np.random.choice([0, 1], size=n))
113+
114+
def test_simple_local_with_pandas(self):
115+
estimator = SimpleLocalDistributionEstimator()
116+
result = estimator.fit(
117+
self.covariates_df,
118+
self.treatment_arms_series,
119+
self.treatment_indicator_series,
120+
self.outcomes_series,
121+
self.strata_series,
122+
)
123+
124+
self.assertIsInstance(result.covariates, np.ndarray)
125+
self.assertIsInstance(result.treatment_arms, np.ndarray)
126+
self.assertIsInstance(result.treatment_indicator, np.ndarray)
127+
self.assertIsInstance(result.outcomes, np.ndarray)
128+
self.assertIsInstance(result.strata, np.ndarray)
129+
130+
def test_adjusted_local_with_pandas(self):
131+
base_model = LogisticRegression(random_state=42)
132+
estimator = AdjustedLocalDistributionEstimator(base_model=base_model)
133+
result = estimator.fit(
134+
self.covariates_df,
135+
self.treatment_arms_series,
136+
self.treatment_indicator_series,
137+
self.outcomes_series,
138+
self.strata_series,
139+
)
140+
141+
self.assertIsInstance(result.covariates, np.ndarray)
142+
self.assertIsInstance(result.treatment_arms, np.ndarray)
143+
self.assertIsInstance(result.treatment_indicator, np.ndarray)
144+
self.assertIsInstance(result.outcomes, np.ndarray)
145+
self.assertIsInstance(result.strata, np.ndarray)
146+
147+
148+
class TestNumpyInputStillWorks(unittest.TestCase):
149+
"""Verify that np.ndarray inputs continue to work as before."""
150+
151+
def test_simple_estimator_with_numpy(self):
152+
estimator = SimpleDistributionEstimator()
153+
covariates = np.zeros((20, 5))
154+
treatment_arms = np.hstack([np.zeros(10), np.ones(10)])
155+
outcomes = np.arange(20, dtype=float)
156+
157+
result = estimator.fit(covariates, treatment_arms, outcomes)
158+
159+
self.assertIsInstance(result.covariates, np.ndarray)
160+
self.assertIsInstance(result.treatment_arms, np.ndarray)
161+
self.assertIsInstance(result.outcomes, np.ndarray)

0 commit comments

Comments
 (0)