Skip to content

Commit b6ee9fb

Browse files
committed
gridsearch
1 parent d9fcea3 commit b6ee9fb

8 files changed

Lines changed: 440 additions & 156 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ install-no-extras-for-test:
8787
python -m pip install .[test]
8888

8989
install-all-extras-for-test:
90-
python -m pip install .[all_extras,test]
90+
python -m pip install .[all_extras,test,test_parallel_backends]
9191

9292
install-editable:
9393
pip install -e .

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ test = [
6060
"pytest-cov",
6161
"pathos",
6262
]
63+
test_parallel_backends = [
64+
"dask",
65+
"joblib",
66+
"ray",
67+
]
6368
all_extras = [
6469
"hyperactive[integrations]",
6570
]

src/hyperactive/opt/gridsearch/_sk.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn.model_selection import ParameterGrid
99

1010
from hyperactive.base import BaseOptimizer
11+
from hyperactive.utils.parallel import parallelize
1112

1213

1314
class GridSearchSk(BaseOptimizer):
@@ -18,8 +19,50 @@ class GridSearchSk(BaseOptimizer):
1819
param_grid : dict[str, list]
1920
The search space to explore. A dictionary with parameter
2021
names as keys and a numpy array as values.
22+
2123
error_score : float, default=np.nan
2224
The score to assign if an error occurs during the evaluation of a parameter set.
25+
26+
backend : {"dask", "loky", "multiprocessing", "threading", "joblib", "ray"}, by default "None".
27+
Runs parallel evaluate if specified and ``strategy`` is set as "refit".
28+
29+
- "None": executes loop sequentially, simple list comprehension
30+
- "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
31+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
32+
- "dask": uses ``dask``, requires ``dask`` package in environment
33+
- "ray": uses ``ray``, requires ``ray`` package in environment
34+
35+
Recommendation: Use "dask" or "loky" for parallel evaluate.
36+
"threading" is unlikely to see speed ups due to the GIL and the serialization
37+
backend (``cloudpickle``) for "dask" and "loky" is generally more robust
38+
than the standard ``pickle`` library used in "multiprocessing".
39+
40+
backend_params : dict, optional
41+
additional parameters passed to the backend as config.
42+
Directly passed to ``utils.parallel.parallelize``.
43+
Valid keys depend on the value of ``backend``:
44+
45+
- "None": no additional parameters, ``backend_params`` is ignored
46+
- "loky", "multiprocessing" and "threading": default ``joblib`` backends
47+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
48+
with the exception of ``backend`` which is directly controlled by ``backend``.
49+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
50+
will default to ``joblib`` defaults.
51+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
52+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
53+
``backend`` must be passed as a key of ``backend_params`` in this case.
54+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
55+
will default to ``joblib`` defaults.
56+
- "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``
57+
58+
- "ray": The following keys can be passed:
59+
60+
- "ray_remote_args": dictionary of valid keys for ``ray.init``
61+
- "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
62+
down after parallelization.
63+
- "logger_name": str, default="ray"; name of the logger to use.
64+
- "mute_warnings": bool, default=False; if True, suppresses warnings
65+
2366
experiment : BaseExperiment, optional
2467
The experiment to optimize parameters for.
2568
Optional, can be passed later via ``set_params``.
@@ -60,11 +103,15 @@ def __init__(
60103
self,
61104
param_grid=None,
62105
error_score=np.nan,
106+
backend="None",
107+
backend_params=None,
63108
experiment=None,
64109
):
65110
self.experiment = experiment
66111
self.param_grid = param_grid
67112
self.error_score = error_score
113+
self.backend = backend
114+
self.backend_params = backend_params
68115

69116
super().__init__()
70117

@@ -97,14 +144,18 @@ def _run(self, experiment, param_grid, error_score):
97144
self._check_param_grid(param_grid)
98145
candidate_params = list(ParameterGrid(param_grid))
99146

100-
scores = []
101-
for candidate_param in candidate_params:
102-
try:
103-
score = experiment(**candidate_param)
104-
except Exception: # noqa: B904
105-
# Catch all exceptions and assign error_score
106-
score = error_score
107-
scores.append(score)
147+
meta = {
148+
"experiment": experiment,
149+
"error_score": error_score,
150+
}
151+
152+
scores = parallelize(
153+
fun=_score_params,
154+
iter=candidate_params,
155+
meta=meta,
156+
backend=self.backend,
157+
backend_params=self.backend_params,
158+
)
108159

109160
best_index = np.argmin(scores)
110161
best_params = candidate_params[best_index]
@@ -170,5 +221,30 @@ def get_test_params(cls, parameter_set="default"):
170221
"experiment": ackley_exp,
171222
"param_grid": param_grid,
172223
}
173-
174-
return [params_sklearn, params_ackley]
224+
225+
params = [params_sklearn, params_ackley]
226+
227+
from hyperactive.utils.parallel import _get_parallel_test_fixtures
228+
229+
parallel_fixtures = _get_parallel_test_fixtures()
230+
231+
for k, v in parallel_fixtures.items():
232+
new_ackley = params_ackley.copy()
233+
new_ackley["backend"] = k
234+
new_ackley["backend_params"] = v
235+
params.append(new_ackley)
236+
237+
return params
238+
239+
240+
def _score_params(params, meta):
241+
"""Function to score parameters, used in parallelization."""
242+
meta = meta.copy()
243+
experiment = meta["experiment"]
244+
error_score = meta["error_score"]
245+
246+
try:
247+
return experiment(**params)
248+
except Exception: # noqa: B904
249+
# Catch all exceptions and assign error_score
250+
return error_score

src/hyperactive/utils/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1 @@
11
"""Utility functionality."""
2-
3-
from hyperactive.utils.estimator_checks import check_estimator
4-
5-
__all__ = [
6-
"check_estimator",
7-
]

src/hyperactive/utils/estimator_checks.py

Lines changed: 0 additions & 139 deletions
This file was deleted.

0 commit comments

Comments
 (0)