Skip to content

Commit 1a4f76d

Browse files
committed
create separate module for gfo-adapter
1 parent 0e9d8ab commit 1a4f76d

4 files changed

Lines changed: 84 additions & 13 deletions

File tree

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Adapters for individual packages."""
2+
3+
# copyright: hyperactive developers, MIT License (see LICENSE file)
4+
5+
from ._gfo import _BaseGFOadapter

src/hyperactive/opt/_adapters/_gfo.py renamed to src/hyperactive/opt/_adapters/_gfo/_gfo.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from hyperactive.base import BaseOptimizer
77
from skbase.utils.stdout_mute import StdoutMute
88

9+
from ._objective_function import ObjectiveFunction
10+
911
__all__ = ["_BaseGFOadapter"]
1012

1113

@@ -58,7 +60,10 @@ def get_search_config(self):
5860

5961
search_config = self._handle_gfo_defaults(search_config)
6062

61-
search_config["search_space"] = self._to_dict_np(search_config["search_space"])
63+
self.search_space_hyper = search_config["search_space"]
64+
search_config["search_space"] = self._conv_search_space(
65+
search_config["search_space"]
66+
)
6267

6368
return search_config
6469

@@ -85,6 +90,18 @@ def _handle_gfo_defaults(self, search_config):
8590

8691
return search_config
8792

93+
@staticmethod
94+
def _conv_search_space(search_space):
95+
# convert hyper search-space into gfo search-space
96+
search_space_gfo = {}
97+
for key in search_space.keys():
98+
search_space_gfo[key] = np.array(range(len(search_space[key])))
99+
return search_space_gfo
100+
101+
@staticmethod
102+
def _conv_objective_function(objective_function, search_space):
103+
return ObjectiveFunction(objective_function).convert(search_space)
104+
88105
def _to_dict_np(self, search_space):
89106
"""Coerce the search space to a format suitable for gfo optimizers.
90107
@@ -108,7 +125,7 @@ def coerce_to_numpy(arr):
108125
if not isinstance(arr, np.ndarray):
109126
return np.array(arr)
110127
return arr
111-
128+
112129
coerced_search_space = {k: coerce_to_numpy(v) for k, v in search_space.items()}
113130
return coerced_search_space
114131

@@ -129,23 +146,18 @@ def _run(self, experiment, **search_config):
129146
n_iter = search_config.pop("n_iter", 100)
130147
max_time = search_config.pop("max_time", None)
131148

132-
# convert hyper search-space into gfo search-space
133-
search_space_hyper = search_config["search_space"]
134-
search_space_gfo = {}
135-
for key in search_space_hyper.keys():
136-
search_space_gfo[key] = np.array(range(len(search_space_hyper[key])))
137-
search_config["search_space"] = search_space_gfo
138-
139149
gfo_cls = self._get_gfo_class()
140-
hcopt = gfo_cls(**search_config)
150+
opt = gfo_cls(**search_config)
151+
152+
score = self._conv_objective_function(experiment, self.search_space_hyper)
141153

142154
with StdoutMute(active=not self.verbose):
143-
hcopt.search(
144-
objective_function=experiment.score,
155+
opt.search(
156+
objective_function=score,
145157
n_iter=n_iter,
146158
max_time=max_time,
147159
)
148-
best_params = hcopt.best_para
160+
best_params = opt.best_para
149161
return best_params
150162

151163
@classmethod
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Author: Simon Blanke
2+
# Email: simon.blanke@yahoo.com
3+
# License: MIT License
4+
5+
6+
from .dictionary import DictClass
7+
8+
9+
def gfo2hyper(search_space, para):
10+
values_dict = {}
11+
for _, key in enumerate(search_space.keys()):
12+
pos_ = int(para[key])
13+
values_dict[key] = search_space[key][pos_]
14+
15+
return values_dict
16+
17+
18+
class ObjectiveFunction(DictClass):
19+
def __init__(self, objective_function):
20+
super().__init__()
21+
22+
self.objective_function = objective_function
23+
24+
def run_callbacks(self, type_):
25+
if self.callbacks and type_ in self.callbacks:
26+
[callback(self) for callback in self.callbacks[type_]]
27+
28+
def convert(self, search_space):
29+
# wrapper for GFOs
30+
def _model(para):
31+
para = gfo2hyper(search_space, para)
32+
self.para_dict = para
33+
34+
return self.objective_function(self)
35+
36+
_model.__name__ = self.objective_function.__name__
37+
return _model
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Author: Simon Blanke
2+
# Email: simon.blanke@yahoo.com
3+
# License: MIT License
4+
5+
6+
class DictClass:
7+
def __init__(self):
8+
self.para_dict = {}
9+
10+
def __getitem__(self, key):
11+
return self.para_dict[key]
12+
13+
def keys(self):
14+
return self.para_dict.keys()
15+
16+
def values(self):
17+
return self.para_dict.values()

0 commit comments

Comments
 (0)