Skip to content

Commit 9b57a6f

Browse files
committed
Use array_to_params to translate categoricals back for res()
1 parent 1974038 commit 9b57a6f

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

bayes_opt/target_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def res(self) -> list[dict[str, Any]]:
672672

673673
return [{"target": target, "params": param} for target, param in zip(self.target, params)]
674674

675-
params = [dict(zip(self.keys, p)) for p in self.params]
675+
params = [self.array_to_params(p) for p in self.params]
676676

677677
return [
678678
{"target": target, "constraint": constraint_value, "params": param, "allowed": allowed}

tests/test_target_space.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,54 @@ def test_res():
293293
assert space.res() == expected_res
294294

295295

296+
def test_res_categorical() -> None:
297+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
298+
299+
def _f(p1: float, p2: str) -> float:
300+
return p1 + len(p2)
301+
302+
space = TargetSpace(_f, PBOUNDS)
303+
304+
assert space.res() == []
305+
space.probe(params={"p1": 1, "p2": "a"})
306+
space.probe(params={"p1": 5, "p2": "b"})
307+
space.probe(params={"p1": 2, "p2": "c"})
308+
space.probe(params={"p1": 2, "p2": "a"})
309+
310+
expected_res = [
311+
{"params": {"p1": 1, "p2": "a"}, "target": 2},
312+
{"params": {"p1": 5, "p2": "b"}, "target": 6},
313+
{"params": {"p1": 2, "p2": "c"}, "target": 3},
314+
{"params": {"p1": 2, "p2": "a"}, "target": 3},
315+
]
316+
assert len(space.res()) == 4
317+
assert space.res() == expected_res
318+
319+
320+
def test_res_categorical_with_constraints() -> None:
321+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
322+
323+
def _f(p1: float, p2: str) -> float:
324+
return p1 + len(p2)
325+
326+
space = TargetSpace(_f, PBOUNDS, constraint=NonlinearConstraint(lambda p1, p2: p1 - 2, 0, 5))
327+
328+
assert space.res() == []
329+
space.probe(params={"p1": 1, "p2": "a"})
330+
space.probe(params={"p1": 5, "p2": "b"})
331+
space.probe(params={"p1": 2, "p2": "c"})
332+
space.probe(params={"p1": 2, "p2": "a"})
333+
334+
expected_res = [
335+
{"params": {"p1": 1, "p2": "a"}, "target": 2, "allowed": False, "constraint": -1},
336+
{"params": {"p1": 5, "p2": "b"}, "target": 6, "allowed": True, "constraint": 3},
337+
{"params": {"p1": 2, "p2": "c"}, "target": 3, "allowed": True, "constraint": 0},
338+
{"params": {"p1": 2, "p2": "a"}, "target": 3, "allowed": True, "constraint": 0},
339+
]
340+
assert len(space.res()) == 4
341+
assert space.res() == expected_res
342+
343+
296344
def test_set_bounds():
297345
pbounds = {"p1": (0, 1), "p3": (0, 3), "p2": (0, 2), "p4": (0, 4)}
298346
space = TargetSpace(target_func, pbounds)

0 commit comments

Comments
 (0)