from sklearn.model_selection import ParameterGrid

from hyperactive.base import BaseOptimizer
from hyperactive.opt._common import _score_params
from hyperactive.utils.parallel import parallelize
1012
1113
1214class GridSearchSk (BaseOptimizer ):
@@ -17,8 +19,45 @@ class GridSearchSk(BaseOptimizer):
1719 param_grid : dict[str, list]
1820 The search space to explore. A dictionary with parameter
1921 names as keys and a numpy array as values.
22+
2023 error_score : float, default=np.nan
2124 The score to assign if an error occurs during the evaluation of a parameter set.
25+
    backend : {"None", "loky", "multiprocessing", "threading", "joblib", "dask", "ray"}, default="None"
        Parallelization backend to use in the search process.
28+
        - "None": executes the loop sequentially, as a simple list comprehension
30+ - "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
31+ - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
32+ - "dask": uses ``dask``, requires ``dask`` package in environment
33+ - "ray": uses ``ray``, requires ``ray`` package in environment
34+
35+ backend_params : dict, optional
36+ additional parameters passed to the backend as config.
37+ Directly passed to ``utils.parallel.parallelize``.
38+ Valid keys depend on the value of ``backend``:
39+
40+ - "None": no additional parameters, ``backend_params`` is ignored
41+ - "loky", "multiprocessing" and "threading": default ``joblib`` backends
42+ any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
43+ with the exception of ``backend`` which is directly controlled by ``backend``.
44+ If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
45+ will default to ``joblib`` defaults.
46+ - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
47+ any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
48+ ``backend`` must be passed as a key of ``backend_params`` in this case.
49+ If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
50+ will default to ``joblib`` defaults.
51+ - "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``
52+
53+ - "ray": The following keys can be passed:
54+
55+ - "ray_remote_args": dictionary of valid keys for ``ray.init``
56+ - "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
57+ down after parallelization.
58+ - "logger_name": str, default="ray"; name of the logger to use.
59+ - "mute_warnings": bool, default=False; if True, suppresses warnings
60+
2261 experiment : BaseExperiment, optional
2362 The experiment to optimize parameters for.
2463 Optional, can be passed later via ``set_params``.
@@ -53,17 +92,29 @@ class GridSearchSk(BaseOptimizer):
5392
5493 Best parameters can also be accessed via the attributes:
5594 >>> best_params = grid_search.best_params_
95+
96+ To parallelize the search, set the ``backend`` and ``backend_params``:
97+ >>> grid_search = GridSearch(
98+ ... param_grid,
99+ ... backend="joblib",
100+ ... backend_params={"n_jobs": -1},
101+ ... experiment=sklearn_exp,
102+ ... )
56103 """
57104
58105 def __init__ (
59106 self ,
60107 param_grid = None ,
61108 error_score = np .nan ,
109+ backend = "None" ,
110+ backend_params = None ,
62111 experiment = None ,
63112 ):
64113 self .experiment = experiment
65114 self .param_grid = param_grid
66115 self .error_score = error_score
116+ self .backend = backend
117+ self .backend_params = backend_params
67118
68119 super ().__init__ ()
69120
@@ -91,19 +142,23 @@ def _check_param_grid(self, param_grid):
91142 "to be a non-empty sequence."
92143 )
93144
94- def _solve (self , experiment , param_grid , error_score ):
145+ def _solve (self , experiment , param_grid , error_score , backend , backend_params ):
95146 """Run the optimization search process."""
96147 self ._check_param_grid (param_grid )
97148 candidate_params = list (ParameterGrid (param_grid ))
98149
99- scores = []
100- for candidate_param in candidate_params :
101- try :
102- score = experiment (** candidate_param )
103- except Exception : # noqa: B904
104- # Catch all exceptions and assign error_score
105- score = error_score
106- scores .append (score )
150+ meta = {
151+ "experiment" : experiment ,
152+ "error_score" : error_score ,
153+ }
154+
155+ scores = parallelize (
156+ fun = _score_params ,
157+ iter = candidate_params ,
158+ meta = meta ,
159+ backend = backend ,
160+ backend_params = backend_params ,
161+ )
107162
108163 best_index = np .argmin (scores )
109164 best_params = candidate_params [best_index ]
@@ -170,4 +225,15 @@ def get_test_params(cls, parameter_set="default"):
170225 "param_grid" : param_grid ,
171226 }
172227
173- return [params_sklearn , params_ackley ]
228+ params = [params_sklearn , params_ackley ]
229+
230+ from hyperactive .utils .parallel import _get_parallel_test_fixtures
231+
232+ parallel_fixtures = _get_parallel_test_fixtures ()
233+
234+ for x in parallel_fixtures :
235+ new_ackley = params_ackley .copy ()
236+ new_ackley .update (x )
237+ params .append (new_ackley )
238+
239+ return params