Skip to content

Commit fa1a32a

Browse files
authored
Fix indents (#1493)
1 parent 5eb7d62 commit fa1a32a

1 file changed

Lines changed: 32 additions & 18 deletions

File tree

website/docs/Use-Cases/Task-Oriented-AutoML.md

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ To tune a custom estimator that is not built-in, you need to:
242242

243243
```python
244244
from flaml.automl.model import SKLearnEstimator
245+
245246
# SKLearnEstimator is derived from BaseEstimator
246247
import rgf
247248

@@ -250,31 +251,44 @@ class MyRegularizedGreedyForest(SKLearnEstimator):
250251
def __init__(self, task="binary", **config):
251252
super().__init__(task, **config)
252253

253-
if task in CLASSIFICATION:
254-
from rgf.sklearn import RGFClassifier
254+
if isinstance(task, str):
255+
from flaml.automl.task.factory import task_factory
256+
257+
task = task_factory(task)
258+
259+
if task.is_classification():
260+
from rgf.sklearn import RGFClassifier
255261

256-
self.estimator_class = RGFClassifier
262+
self.estimator_class = RGFClassifier
257263
else:
258-
from rgf.sklearn import RGFRegressor
264+
from rgf.sklearn import RGFRegressor
259265

260-
self.estimator_class = RGFRegressor
266+
self.estimator_class = RGFRegressor
261267

262268
@classmethod
263269
def search_space(cls, data_size, task):
264270
space = {
265-
"max_leaf": {
266-
"domain": tune.lograndint(lower=4, upper=data_size),
267-
"low_cost_init_value": 4,
268-
},
269-
"n_iter": {
270-
"domain": tune.lograndint(lower=1, upper=data_size),
271-
"low_cost_init_value": 1,
272-
},
273-
"learning_rate": {"domain": tune.loguniform(lower=0.01, upper=20.0)},
274-
"min_samples_leaf": {
275-
"domain": tune.lograndint(lower=1, upper=20),
276-
"init_value": 20,
277-
},
271+
"max_leaf": {
272+
"domain": tune.lograndint(lower=4, upper=data_size[0]),
273+
"init_value": 4,
274+
},
275+
"n_iter": {
276+
"domain": tune.lograndint(lower=1, upper=data_size[0]),
277+
"init_value": 1,
278+
},
279+
"n_tree_search": {
280+
"domain": tune.lograndint(lower=1, upper=32768),
281+
"init_value": 1,
282+
},
283+
"opt_interval": {
284+
"domain": tune.lograndint(lower=1, upper=10000),
285+
"init_value": 100,
286+
},
287+
"learning_rate": {"domain": tune.loguniform(lower=0.01, upper=20.0)},
288+
"min_samples_leaf": {
289+
"domain": tune.lograndint(lower=1, upper=20),
290+
"init_value": 20,
291+
},
278292
}
279293
return space
280294
```

0 commit comments

Comments
 (0)