Skip to content

Commit dbe2be2

Browse files
committed
Add LightGBM Experiment integration
1 parent 7e049cb commit dbe2be2

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

src/hyperactive/experiment/integrations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# copyright: hyperactive developers, MIT License (see LICENSE file)
33

44
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
5+
from hyperactive.experiment.integrations.lightgbm_experiment import LightGBMExperiment
56
from hyperactive.experiment.integrations.skpro_probareg import (
67
SkproProbaRegExperiment,
78
)
@@ -21,4 +22,5 @@
2122
"SktimeClassificationExperiment",
2223
"SktimeForecastingExperiment",
2324
"TorchExperiment",
25+
"LightGBMExperiment",
2426
]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Experiment adapter for LightGBM cross-validation experiments."""
2+
3+
# copyright: hyperactive developers, MIT License (see LICENSE file)
4+
5+
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
6+
7+
8+
class LightGBMExperiment(SklearnCvExperiment):
9+
"""Experiment adapter for LightGBM cross-validation experiments.
10+
11+
Thin wrapper around SklearnCvExperiment for LightGBM estimators.
12+
13+
LightGBM estimators follow the sklearn API, so this class does not
14+
add new functionality beyond SklearnCvExperiment. It exists for
15+
discoverability and explicit LightGBM support.
16+
"""
17+
18+
_tags = {
19+
"python_dependencies": "lightgbm",
20+
}
21+
22+
@classmethod
23+
def get_test_params(cls, parameter_set="default"):
24+
"""Return testing parameter settings for the estimator."""
25+
from skbase.utils.dependencies import _check_soft_dependencies
26+
27+
if not _check_soft_dependencies("lightgbm", severity="none"):
28+
return []
29+
30+
from sklearn.datasets import load_iris, load_diabetes
31+
from lightgbm import LGBMClassifier, LGBMRegressor
32+
33+
# Classification test case
34+
X, y = load_iris(return_X_y=True)
35+
params0 = {
36+
"estimator": LGBMClassifier(n_estimators=10),
37+
"X": X,
38+
"y": y,
39+
"cv": 2,
40+
}
41+
42+
# Regression test case
43+
X, y = load_diabetes(return_X_y=True)
44+
params1 = {
45+
"estimator": LGBMRegressor(n_estimators=10),
46+
"X": X,
47+
"y": y,
48+
"cv": 2,
49+
}
50+
51+
return [params0, params1]
52+
53+
@classmethod
54+
def _get_score_params(cls):
55+
"""Return parameter settings for score/evaluate tests."""
56+
from skbase.utils.dependencies import _check_soft_dependencies
57+
58+
if not _check_soft_dependencies("lightgbm", severity="none"):
59+
return []
60+
61+
val0 = {"n_estimators": 5, "max_depth": 2}
62+
val1 = {"n_estimators": 5, "max_depth": 2}
63+
64+
return [val0, val1]

0 commit comments

Comments
 (0)