Skip to content

Commit ceb8d40

Browse files
authored
Merge pull request #392 from DoubleML/jd-fix-plpr-bug
fix d_mean calculation in PLPR, update tests
2 parents 2d666c5 + 32db11a commit ceb8d40

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

doubleml/plm/plpr.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -335,9 +335,9 @@ def _transform_data(self):
335335

336336
def _set_d_mean(self):
337337
if self._approach in ["cre_general", "cre_normal"]:
338-
data = self._original_dml_data.data
339-
d_cols = self._original_dml_data.d_cols
340-
id_col = self._original_dml_data.id_col
338+
data = self._dml_data.data
339+
d_cols = self._dml_data.d_cols
340+
id_col = self._dml_data.id_col
341341
help_d_mean = data.loc[:, [id_col] + d_cols]
342342
d_mean = help_d_mean.groupby(id_col).transform("mean").values
343343
self._d_mean = d_mean

doubleml/plm/tests/test_plpr_transformations.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ def test_plpr_approach_x_dim(approach, time_type):
8787
static_panel=True,
8888
)
8989
dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach)
90+
dml_plpr.fit()
9091
if approach == "wg_approx":
9192
assert len(dml_plpr._dml_data.x_cols) == dim_x
9293
else:
@@ -106,6 +107,7 @@ def test_plpr_approach_d_mean(approach, time_type):
106107
static_panel=True,
107108
)
108109
dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach)
110+
dml_plpr.fit()
109111
if approach in ["cre_general", "cre_normal"]:
110112
assert dml_plpr.d_mean is not None
111113
else:
@@ -145,6 +147,7 @@ def test_plpr_fd_exact_unbalanced(time_type):
145147
)
146148
with pytest.warns(UserWarning, match=msg_warn):
147149
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2)
150+
obj_plpr.fit()
148151
# 4 rows after fd transformation as id 3 has no possible first difference
149152
assert obj_plpr.data_transform.data.shape[0] == 4
150153

@@ -168,6 +171,7 @@ def test_plpr_one_id(approach, time_type):
168171
)
169172
with pytest.warns(UserWarning, match=msg_warn):
170173
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach, n_folds=2)
174+
obj_plpr.fit()
171175
# 2 rows after fd transformation, 4 rows else
172176
if approach == "fd_exact":
173177
assert obj_plpr.data_transform.data.shape[0] == 2
@@ -196,6 +200,7 @@ def test_plpr_fd_exact_one_id_unbalanced(time_type):
196200
# capture warnings
197201
with pytest.warns(UserWarning) as record:
198202
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2)
203+
obj_plpr.fit()
199204
# assert two warnings were raised and content
200205
assert len(record) == 2
201206
assert msg_warn_one_id in str(record[0].message)
@@ -215,6 +220,7 @@ def test_plpr_time_cre_transformation(cre_approach, data_time_type):
215220
static_panel=True,
216221
)
217222
dml_cre = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=cre_approach, n_folds=2)
223+
dml_cre.fit()
218224
assert dml_cre.transform_cols["y_col"] == "y"
219225
assert dml_cre.transform_cols["d_cols"] == ["d"]
220226
assert dml_cre.transform_cols["x_cols"] == ["x1", "x2", "x1_mean", "x2_mean"]

0 commit comments

Comments
 (0)