@@ -87,6 +87,7 @@ def test_plpr_approach_x_dim(approach, time_type):
8787 static_panel = True ,
8888 )
8989 dml_plpr = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = approach )
90+ dml_plpr .fit ()
9091 if approach == "wg_approx" :
9192 assert len (dml_plpr ._dml_data .x_cols ) == dim_x
9293 else :
@@ -106,6 +107,7 @@ def test_plpr_approach_d_mean(approach, time_type):
106107 static_panel = True ,
107108 )
108109 dml_plpr = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = approach )
110+ dml_plpr .fit ()
109111 if approach in ["cre_general" , "cre_normal" ]:
110112 assert dml_plpr .d_mean is not None
111113 else :
@@ -145,6 +147,7 @@ def test_plpr_fd_exact_unbalanced(time_type):
145147 )
146148 with pytest .warns (UserWarning , match = msg_warn ):
147149 obj_plpr = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = "fd_exact" , n_folds = 2 )
150+ obj_plpr .fit ()
148151 # 4 rows after fd transformation as id 3 has no possible first difference
149152 assert obj_plpr .data_transform .data .shape [0 ] == 4
150153
@@ -168,6 +171,7 @@ def test_plpr_one_id(approach, time_type):
168171 )
169172 with pytest .warns (UserWarning , match = msg_warn ):
170173 obj_plpr = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = approach , n_folds = 2 )
174+ obj_plpr .fit ()
171175 # 2 rows after fd transformation, 4 rows else
172176 if approach == "fd_exact" :
173177 assert obj_plpr .data_transform .data .shape [0 ] == 2
@@ -196,6 +200,7 @@ def test_plpr_fd_exact_one_id_unbalanced(time_type):
196200 # capture warnings
197201 with pytest .warns (UserWarning ) as record :
198202 obj_plpr = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = "fd_exact" , n_folds = 2 )
203+ obj_plpr .fit ()
199204 # assert two warnings were raised and content
200205 assert len (record ) == 2
201206 assert msg_warn_one_id in str (record [0 ].message )
@@ -215,6 +220,7 @@ def test_plpr_time_cre_transformation(cre_approach, data_time_type):
215220 static_panel = True ,
216221 )
217222 dml_cre = dml .DoubleMLPLPR (obj_dml_data , ml_l , ml_m , approach = cre_approach , n_folds = 2 )
223+ dml_cre .fit ()
218224 assert dml_cre .transform_cols ["y_col" ] == "y"
219225 assert dml_cre .transform_cols ["d_cols" ] == ["d" ]
220226 assert dml_cre .transform_cols ["x_cols" ] == ["x1" , "x2" , "x1_mean" , "x2_mean" ]
0 commit comments