# Module-wide markers: skip every test here when Spark is unavailable,
# and tag them all as spark tests for selective collection.
pytestmark = [
    pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests."),
    pytest.mark.spark,
]
2929
3030
31- def test_parallel_xgboost (hpo_method = None , data_size = 1000 ):
31+ def test_parallel_xgboost_and_pickle (hpo_method = None , data_size = 1000 ):
3232 automl_experiment = AutoML ()
3333 automl_settings = {
34- "time_budget" : 10 ,
34+ "time_budget" : 30 ,
3535 "metric" : "ap" ,
3636 "task" : "classification" ,
3737 "log_file_name" : "test/sparse_classification.log" ,
@@ -53,15 +53,27 @@ def test_parallel_xgboost(hpo_method=None, data_size=1000):
5353 print (automl_experiment .best_iteration )
5454 print (automl_experiment .best_estimator )
5555
56+ # test pickle and load_pickle, should work for prediction
57+ automl_experiment .pickle ("automl_xgboost_spark.pkl" )
58+ automl_loaded = AutoML ().load_pickle ("automl_xgboost_spark.pkl" )
59+ assert automl_loaded .best_estimator == automl_experiment .best_estimator
60+ assert automl_loaded .best_loss == automl_experiment .best_loss
61+ automl_loaded .predict (X_train )
62+
63+ import shutil
64+
65+ shutil .rmtree ("automl_xgboost_spark.pkl" , ignore_errors = True )
66+ shutil .rmtree ("automl_xgboost_spark.pkl.flaml_artifacts" , ignore_errors = True )
67+
5668
def test_parallel_xgboost_others():
    """Repeat the parallel XGBoost + pickle test, driving HPO with random search."""
    # First positional argument binds to hpo_method.
    test_parallel_xgboost_and_pickle("random")
6072
6173
@pytest.mark.skip(reason="currently not supporting too large data, will support spark dataframe in the future")
def test_large_dataset():
    """Run the parallel XGBoost + pickle test on a very large synthetic dataset.

    Currently skipped (see the marker) until Spark dataframe input is supported.
    """
    huge_size = 90000000
    test_parallel_xgboost_and_pickle(data_size=huge_size)
6577
6678
6779@pytest .mark .skipif (
@@ -95,10 +107,10 @@ def test_custom_learner(data_size=1000):
95107
96108
if __name__ == "__main__":
    # Direct-execution entry point for local debugging (outside pytest).
    # Only the default parallel-XGBoost + pickle scenario runs here; the
    # other scenarios (random search, large data, custom learner) remain
    # reachable through pytest via their test_* functions above.
    # NOTE(review): previously this also invoked test_parallel_xgboost_others
    # and conditionally test_custom_learner; removed as dead commented-out
    # code rather than left as comment clutter.
    test_parallel_xgboost_and_pickle()
0 commit comments