-
Notifications
You must be signed in to change notification settings - Fork 75
Expand file tree
/
Copy pathlightgbm_experiment.py
More file actions
64 lines (47 loc) · 1.95 KB
/
lightgbm_experiment.py
File metadata and controls
64 lines (47 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Experiment adapter for LightGBM cross-validation experiments."""
# copyright: hyperactive developers, MIT License (see LICENSE file)
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
class LightGBMExperiment(SklearnCvExperiment):
    """Cross-validation experiment for LightGBM estimators.

    This class subclasses SklearnCvExperiment without overriding any of its
    evaluation logic. LightGBM estimators follow the sklearn API, so no
    additional functionality is needed here; the class is provided for
    discoverability and to make LightGBM support explicit.
    """

    # "lightgbm" is the package name recorded under the dependency tag
    _tags = {"python_dependencies": "lightgbm"}

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return constructor parameter sets used for testing.

        Parameters
        ----------
        parameter_set : str, default="default"
            Not used by this implementation; the same parameter sets are
            returned regardless of its value.

        Returns
        -------
        list of dict
            Empty if the ``lightgbm`` soft-dependency check fails. Otherwise
            two dicts with keys ``"estimator"``, ``"X"``, ``"y"``, ``"cv"``:
            an ``LGBMClassifier`` on the iris data, followed by an
            ``LGBMRegressor`` on the diabetes data.
        """
        from skbase.utils.dependencies import _check_soft_dependencies

        lightgbm_available = _check_soft_dependencies("lightgbm", severity="none")
        if not lightgbm_available:
            return []

        # imported only after the dependency check has passed
        from sklearn.datasets import load_iris, load_diabetes
        from lightgbm import LGBMClassifier, LGBMRegressor

        # (data loader, estimator class): classification first, then regression
        scenarios = (
            (load_iris, LGBMClassifier),
            (load_diabetes, LGBMRegressor),
        )

        test_cases = []
        for load_data, estimator_cls in scenarios:
            features, target = load_data(return_X_y=True)
            test_cases.append(
                {
                    "estimator": estimator_cls(n_estimators=10),
                    "X": features,
                    "y": target,
                    "cv": 2,
                }
            )
        return test_cases

    @classmethod
    def _get_score_params(cls):
        """Return parameter dicts used by the score/evaluate tests.

        Returns
        -------
        list of dict
            Empty if the ``lightgbm`` soft-dependency check fails. Otherwise
            two separate but equal dicts, each setting ``n_estimators`` to 5
            and ``max_depth`` to 2.
            NOTE(review): presumably paired by position with the entries of
            ``get_test_params`` - confirm against the base class tests.
        """
        from skbase.utils.dependencies import _check_soft_dependencies

        lightgbm_available = _check_soft_dependencies("lightgbm", severity="none")
        if not lightgbm_available:
            return []

        # build a fresh dict per entry so the two list items are distinct objects
        return [{"n_estimators": 5, "max_depth": 2} for _ in range(2)]