@@ -242,6 +242,7 @@ To tune a custom estimator that is not built-in, you need to:
242242
243243``` python
244244from flaml.automl.model import SKLearnEstimator
245+
245246# SKLearnEstimator is derived from BaseEstimator
246247import rgf
247248
@@ -250,31 +251,44 @@ class MyRegularizedGreedyForest(SKLearnEstimator):
250251 def __init__ (self , task = " binary" , ** config ):
251252 super ().__init__ (task, ** config)
252253
253- if task in CLASSIFICATION :
254- from rgf.sklearn import RGFClassifier
254+ if isinstance (task, str ):
255+ from flaml.automl.task.factory import task_factory
256+
257+ task = task_factory(task)
258+
259+ if task.is_classification():
260+ from rgf.sklearn import RGFClassifier
255261
256- self .estimator_class = RGFClassifier
262+ self .estimator_class = RGFClassifier
257263 else :
258- from rgf.sklearn import RGFRegressor
264+ from rgf.sklearn import RGFRegressor
259265
260- self .estimator_class = RGFRegressor
266+ self .estimator_class = RGFRegressor
261267
262268 @ classmethod
263269 def search_space (cls , data_size , task ):
264270 space = {
265- " max_leaf" : {
266- " domain" : tune.lograndint(lower = 4 , upper = data_size),
267- " low_cost_init_value" : 4 ,
268- },
269- " n_iter" : {
270- " domain" : tune.lograndint(lower = 1 , upper = data_size),
271- " low_cost_init_value" : 1 ,
272- },
273- " learning_rate" : {" domain" : tune.loguniform(lower = 0.01 , upper = 20.0 )},
274- " min_samples_leaf" : {
275- " domain" : tune.lograndint(lower = 1 , upper = 20 ),
276- " init_value" : 20 ,
277- },
271+ " max_leaf" : {
272+ " domain" : tune.lograndint(lower = 4 , upper = data_size[0 ]),
273+ " init_value" : 4 ,
274+ },
275+ " n_iter" : {
276+ " domain" : tune.lograndint(lower = 1 , upper = data_size[0 ]),
277+ " init_value" : 1 ,
278+ },
279+ " n_tree_search" : {
280+ " domain" : tune.lograndint(lower = 1 , upper = 32768 ),
281+ " init_value" : 1 ,
282+ },
283+ " opt_interval" : {
284+ " domain" : tune.lograndint(lower = 1 , upper = 10000 ),
285+ " init_value" : 100 ,
286+ },
287+ " learning_rate" : {" domain" : tune.loguniform(lower = 0.01 , upper = 20.0 )},
288+ " min_samples_leaf" : {
289+ " domain" : tune.lograndint(lower = 1 , upper = 20 ),
290+ " init_value" : 20 ,
291+ },
278292 }
279293 return space
280294```
0 commit comments