Skip to content

Commit f663313

Browse files
authored
[ENH] parallelization backends for grid and random search (#150)
This PR introduces parallelization backends for `GridSearch` and `RandomSearch`. Both estimators now allow specifying `backend` and `backend_params` in the constructor, which allows selection of a parallelization backend and configuration parameters for it - the default being `"None"`, i.e., a plain loop. This uses the `parallel` utilities also used in `sktime` and `skpro`, with a mid-term plan to move these to `scikit-base`.
1 parent 3a14a58 commit f663313

8 files changed

Lines changed: 523 additions & 20 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ install-no-extras-for-test:
9090
python -m pip install .[test]
9191

9292
install-all-extras-for-test:
93-
python -m pip install .[all_extras,sktime-integration,test]
93+
python -m pip install .[all_extras,test,test_parallel_backends,sktime-integration]
9494

9595
install-editable:
9696
pip install -e .

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ test = [
6565
"torch",
6666
"tf_keras",
6767
]
68+
test_parallel_backends = [
69+
"dask",
70+
"joblib",
71+
'ray >=2.40.0; python_version < "3.13"',
72+
]
6873
all_extras = [
6974
"hyperactive[integrations]",
7075
"optuna<5",

src/hyperactive/opt/_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Common functions used by multiple optimizers."""
2+
3+
__all__ = ["_score_params"]
4+
5+
6+
def _score_params(params, meta):
7+
"""Score parameters, used in parallelization."""
8+
meta = meta.copy()
9+
experiment = meta["experiment"]
10+
error_score = meta["error_score"]
11+
12+
try:
13+
return experiment(**params)
14+
except Exception: # noqa: B904
15+
# Catch all exceptions and assign error_score
16+
return error_score

src/hyperactive/opt/gridsearch/_sk.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from sklearn.model_selection import ParameterGrid
88

99
from hyperactive.base import BaseOptimizer
10+
from hyperactive.opt._common import _score_params
11+
from hyperactive.utils.parallel import parallelize
1012

1113

1214
class GridSearchSk(BaseOptimizer):
@@ -17,8 +19,45 @@ class GridSearchSk(BaseOptimizer):
1719
param_grid : dict[str, list]
1820
The search space to explore. A dictionary with parameter
1921
names as keys and numpy arrays as values.
22+
2023
error_score : float, default=np.nan
2124
The score to assign if an error occurs during the evaluation of a parameter set.
25+
26+
backend : {"dask", "loky", "multiprocessing", "threading", "ray"}, default = "None".
27+
Parallelization backend to use in the search process.
28+
29+
- "None": executes loop sequentally, simple list comprehension
30+
- "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
31+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
32+
- "dask": uses ``dask``, requires ``dask`` package in environment
33+
- "ray": uses ``ray``, requires ``ray`` package in environment
34+
35+
backend_params : dict, optional
36+
additional parameters passed to the backend as config.
37+
Directly passed to ``utils.parallel.parallelize``.
38+
Valid keys depend on the value of ``backend``:
39+
40+
- "None": no additional parameters, ``backend_params`` is ignored
41+
- "loky", "multiprocessing" and "threading": default ``joblib`` backends
42+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
43+
with the exception of ``backend`` which is directly controlled by ``backend``.
44+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
45+
will default to ``joblib`` defaults.
46+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
47+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
48+
``backend`` must be passed as a key of ``backend_params`` in this case.
49+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
50+
will default to ``joblib`` defaults.
51+
- "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``
52+
53+
- "ray": The following keys can be passed:
54+
55+
- "ray_remote_args": dictionary of valid keys for ``ray.init``
56+
- "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
57+
down after parallelization.
58+
- "logger_name": str, default="ray"; name of the logger to use.
59+
- "mute_warnings": bool, default=False; if True, suppresses warnings
60+
2261
experiment : BaseExperiment, optional
2362
The experiment to optimize parameters for.
2463
Optional, can be passed later via ``set_params``.
@@ -53,17 +92,29 @@ class GridSearchSk(BaseOptimizer):
5392
5493
Best parameters can also be accessed via the attributes:
5594
>>> best_params = grid_search.best_params_
95+
96+
To parallelize the search, set the ``backend`` and ``backend_params``:
97+
>>> grid_search = GridSearch(
98+
... param_grid,
99+
... backend="joblib",
100+
... backend_params={"n_jobs": -1},
101+
... experiment=sklearn_exp,
102+
... )
56103
"""
57104

58105
def __init__(
59106
self,
60107
param_grid=None,
61108
error_score=np.nan,
109+
backend="None",
110+
backend_params=None,
62111
experiment=None,
63112
):
64113
self.experiment = experiment
65114
self.param_grid = param_grid
66115
self.error_score = error_score
116+
self.backend = backend
117+
self.backend_params = backend_params
67118

68119
super().__init__()
69120

@@ -91,19 +142,23 @@ def _check_param_grid(self, param_grid):
91142
"to be a non-empty sequence."
92143
)
93144

94-
def _solve(self, experiment, param_grid, error_score):
145+
def _solve(self, experiment, param_grid, error_score, backend, backend_params):
95146
"""Run the optimization search process."""
96147
self._check_param_grid(param_grid)
97148
candidate_params = list(ParameterGrid(param_grid))
98149

99-
scores = []
100-
for candidate_param in candidate_params:
101-
try:
102-
score = experiment(**candidate_param)
103-
except Exception: # noqa: B904
104-
# Catch all exceptions and assign error_score
105-
score = error_score
106-
scores.append(score)
150+
meta = {
151+
"experiment": experiment,
152+
"error_score": error_score,
153+
}
154+
155+
scores = parallelize(
156+
fun=_score_params,
157+
iter=candidate_params,
158+
meta=meta,
159+
backend=backend,
160+
backend_params=backend_params,
161+
)
107162

108163
best_index = np.argmin(scores)
109164
best_params = candidate_params[best_index]
@@ -170,4 +225,15 @@ def get_test_params(cls, parameter_set="default"):
170225
"param_grid": param_grid,
171226
}
172227

173-
return [params_sklearn, params_ackley]
228+
params = [params_sklearn, params_ackley]
229+
230+
from hyperactive.utils.parallel import _get_parallel_test_fixtures
231+
232+
parallel_fixtures = _get_parallel_test_fixtures()
233+
234+
for x in parallel_fixtures:
235+
new_ackley = params_ackley.copy()
236+
new_ackley.update(x)
237+
params.append(new_ackley)
238+
239+
return params

src/hyperactive/opt/random_search.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sklearn.model_selection import ParameterSampler
99

1010
from hyperactive.base import BaseOptimizer
11+
from hyperactive.opt._common import _score_params
12+
from hyperactive.utils.parallel import parallelize
1113

1214

1315
class RandomSearchSk(BaseOptimizer):
@@ -18,12 +20,51 @@ class RandomSearchSk(BaseOptimizer):
1820
param_distributions : dict[str, list | scipy.stats.rv_frozen]
1921
Search space specification. Discrete lists are sampled uniformly;
2022
scipy distribution objects are sampled via their ``rvs`` method.
23+
2124
n_iter : int, default=10
2225
Number of parameter sets to evaluate.
26+
2327
random_state : int | np.random.RandomState | None, default=None
2428
Controls the pseudo-random generator for reproducibility.
29+
2530
error_score : float, default=np.nan
2631
Score assigned when the experiment raises an exception.
32+
33+
backend : {"dask", "loky", "multiprocessing", "threading", "ray"}, default = "None".
34+
Parallelization backend to use in the search process.
35+
36+
- "None": executes loop sequentally, simple list comprehension
37+
- "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
38+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
39+
- "dask": uses ``dask``, requires ``dask`` package in environment
40+
- "ray": uses ``ray``, requires ``ray`` package in environment
41+
42+
backend_params : dict, optional
43+
additional parameters passed to the backend as config.
44+
Directly passed to ``utils.parallel.parallelize``.
45+
Valid keys depend on the value of ``backend``:
46+
47+
- "None": no additional parameters, ``backend_params`` is ignored
48+
- "loky", "multiprocessing" and "threading": default ``joblib`` backends
49+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
50+
with the exception of ``backend`` which is directly controlled by ``backend``.
51+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
52+
will default to ``joblib`` defaults.
53+
- "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
54+
any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
55+
``backend`` must be passed as a key of ``backend_params`` in this case.
56+
If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
57+
will default to ``joblib`` defaults.
58+
- "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``
59+
60+
- "ray": The following keys can be passed:
61+
62+
- "ray_remote_args": dictionary of valid keys for ``ray.init``
63+
- "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
64+
down after parallelization.
65+
- "logger_name": str, default="ray"; name of the logger to use.
66+
- "mute_warnings": bool, default=False; if True, suppresses warnings
67+
2768
experiment : BaseExperiment, optional
2869
Callable returning a scalar score when invoked with keyword
2970
arguments matching a parameter set.
@@ -44,13 +85,17 @@ def __init__(
4485
n_iter=10,
4586
random_state=None,
4687
error_score=np.nan,
88+
backend="None",
89+
backend_params=None,
4790
experiment=None,
4891
):
4992
self.experiment = experiment
5093
self.param_distributions = param_distributions
5194
self.n_iter = n_iter
5295
self.random_state = random_state
5396
self.error_score = error_score
97+
self.backend = backend
98+
self.backend_params = backend_params
5499

55100
super().__init__()
56101

@@ -67,7 +112,7 @@ def _check_param_distributions(self, param_distributions):
67112
for p in param_distributions:
68113
for name, v in p.items():
69114
if self._is_distribution(v):
70-
# Assume scipy frozen distribution - nothing to check
115+
# Assume scipy frozen distribution: nothing to check
71116
continue
72117

73118
if isinstance(v, np.ndarray) and v.ndim > 1:
@@ -93,6 +138,8 @@ def _solve(
93138
n_iter,
94139
random_state,
95140
error_score,
141+
backend,
142+
backend_params,
96143
):
97144
"""Sample ``n_iter`` points and return the best parameter set."""
98145
self._check_param_distributions(param_distributions)
@@ -104,13 +151,18 @@ def _solve(
104151
)
105152
candidate_params = list(sampler)
106153

107-
scores: list[float] = []
108-
for candidate_param in candidate_params:
109-
try:
110-
score = experiment(**candidate_param)
111-
except Exception: # noqa: B904
112-
score = error_score
113-
scores.append(score)
154+
meta = {
155+
"experiment": experiment,
156+
"error_score": error_score,
157+
}
158+
159+
scores = parallelize(
160+
fun=_score_params,
161+
iter=candidate_params,
162+
meta=meta,
163+
backend=backend,
164+
backend_params=backend_params,
165+
)
114166

115167
best_index = int(np.argmin(scores)) # lower-is-better convention
116168
best_params = candidate_params[best_index]
@@ -154,4 +206,15 @@ def get_test_params(cls, parameter_set: str = "default"):
154206
"random_state": 0,
155207
}
156208

157-
return [params_sklearn, params_ackley]
209+
params = [params_sklearn, params_ackley]
210+
211+
from hyperactive.utils.parallel import _get_parallel_test_fixtures
212+
213+
parallel_fixtures = _get_parallel_test_fixtures()
214+
215+
for x in parallel_fixtures:
216+
new_ackley = params_ackley.copy()
217+
new_ackley.update(x)
218+
params.append(new_ackley)
219+
220+
return params

0 commit comments

Comments
 (0)