Skip to content

Commit b873c18

Browse files
committed
Use array_to_params to translate categoricals back for res()
1 parent 1974038 commit b873c18

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

bayes_opt/target_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def res(self) -> list[dict[str, Any]]:
672672

673673
return [{"target": target, "params": param} for target, param in zip(self.target, params)]
674674

675-
params = [dict(zip(self.keys, p)) for p in self.params]
675+
params = [self.array_to_params(p) for p in self.params]
676676

677677
return [
678678
{"target": target, "constraint": constraint_value, "params": param, "allowed": allowed}

tests/test_target_space.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,59 @@ def test_res():
293293
assert space.res() == expected_res
294294

295295

296+
def test_res_categorical() -> None:
297+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
298+
299+
def _f(p1: float, p2: str) -> float:
300+
return p1 + len(p2)
301+
302+
space = TargetSpace(_f, PBOUNDS)
303+
304+
assert space.res() == []
305+
space.probe(params={"p1": 1, "p2": "a"})
306+
space.probe(params={"p1": 5, "p2": "b"})
307+
space.probe(params={"p1": 2, "p2": "c"})
308+
space.probe(params={"p1": 2, "p2": "a"})
309+
310+
expected_res = [
311+
{"params": {"p1": 1, "p2": "a"}, "target": 2},
312+
{"params": {"p1": 5, "p2": "b"}, "target": 6},
313+
{"params": {"p1": 2, "p2": "c"}, "target": 3},
314+
{"params": {"p1": 2, "p2": "a"}, "target": 3},
315+
]
316+
assert len(space.res()) == 4
317+
assert space.res() == expected_res
318+
319+
320+
def test_res_categorical_with_constraints() -> None:
321+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
322+
323+
def _f(p1: float, p2: str) -> float:
324+
return p1 + len(p2)
325+
326+
space = TargetSpace(
327+
_f,
328+
PBOUNDS,
329+
constraint=NonlinearConstraint(lambda p1, p2: p1 - 2, 0, 5)
330+
)
331+
332+
assert space.res() == []
333+
space.probe(params={"p1": 1, "p2": "a"})
334+
space.probe(params={"p1": 5, "p2": "b"})
335+
space.probe(params={"p1": 2, "p2": "c"})
336+
space.probe(params={"p1": 2, "p2": "a"})
337+
338+
expected_res = [
339+
{"params": {"p1": 1, "p2": "a"}, "target": 2, 'allowed': False, 'constraint': -1},
340+
{"params": {"p1": 5, "p2": "b"}, "target": 6, 'allowed': True, 'constraint': 3},
341+
{"params": {"p1": 2, "p2": "c"}, "target": 3, 'allowed': True, 'constraint': 0},
342+
{"params": {"p1": 2, "p2": "a"}, "target": 3, 'allowed': True, 'constraint': 0},
343+
]
344+
assert len(space.res()) == 4
345+
assert space.res() == expected_res
346+
347+
348+
296349
def test_set_bounds():
297350
pbounds = {"p1": (0, 1), "p3": (0, 3), "p2": (0, 2), "p4": (0, 4)}
298351
space = TargetSpace(target_func, pbounds)

0 commit comments

Comments
 (0)