Skip to content

Commit 17ff989

Browse files
Fix .max() for categorical (#601)
1 parent 25084fe commit 17ff989

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

bayes_opt/target_space.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def max(self) -> dict[str, Any] | None:
643643
params = self.params[self.mask]
644644
target_max_idx = np.argmax(target)
645645

646-
res = {"target": target_max, "params": dict(zip(self.keys, params[target_max_idx]))}
646+
res = {"target": target_max, "params": self.array_to_params(params[target_max_idx])}
647647

648648
if self._constraint is not None:
649649
constraint_values = self.constraint_values[self.mask]
@@ -672,7 +672,7 @@ def res(self) -> list[dict[str, Any]]:
672672

673673
return [{"target": target, "params": param} for target, param in zip(self.target, params)]
674674

675-
params = [dict(zip(self.keys, p)) for p in self.params]
675+
params = [self.array_to_params(p) for p in self.params]
676676

677677
return [
678678
{"target": target, "constraint": constraint_value, "params": param, "allowed": allowed}

tests/test_target_space.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ def test_max_with_constraint_identical_target_value():
253253
assert space.max() == {"params": {"p1": 2, "p2": 3}, "target": 5, "constraint": -1}
254254

255255

256+
def test_max_categorical() -> None:
257+
PBOUNDS = {
258+
"first_float": (0.0, 1.0),
259+
"categorical_value": ("a", "b", "c", "d"),
260+
"second_float": (0.0, 1.0),
261+
}
262+
263+
def _f(first_float: float, categorical_value: str, second_float: float) -> float:
264+
return second_float if categorical_value == "c" else first_float
265+
266+
space = TargetSpace(_f, PBOUNDS)
267+
space.probe(params={"first_float": 0.1, "categorical_value": "a", "second_float": 0.1})
268+
space.probe(params={"first_float": 0.1, "categorical_value": "b", "second_float": 0.9})
269+
space.probe(params={"first_float": 0.1, "categorical_value": "c", "second_float": 0.8})
270+
space.probe(params={"first_float": 0.1, "categorical_value": "d", "second_float": 0.9})
271+
272+
expected = {"first_float": 0.1, "categorical_value": "c", "second_float": 0.8}
273+
assert space.max()["params"] == expected
274+
275+
256276
def test_res():
257277
PBOUNDS = {"p1": (0, 10), "p2": (1, 100)}
258278
space = TargetSpace(target_func, PBOUNDS)
@@ -273,6 +293,54 @@ def test_res():
273293
assert space.res() == expected_res
274294

275295

296+
def test_res_categorical() -> None:
297+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
298+
299+
def _f(p1: float, p2: str) -> float:
300+
return p1 + len(p2)
301+
302+
space = TargetSpace(_f, PBOUNDS)
303+
304+
assert space.res() == []
305+
space.probe(params={"p1": 1, "p2": "a"})
306+
space.probe(params={"p1": 5, "p2": "b"})
307+
space.probe(params={"p1": 2, "p2": "c"})
308+
space.probe(params={"p1": 2, "p2": "a"})
309+
310+
expected_res = [
311+
{"params": {"p1": 1, "p2": "a"}, "target": 2},
312+
{"params": {"p1": 5, "p2": "b"}, "target": 6},
313+
{"params": {"p1": 2, "p2": "c"}, "target": 3},
314+
{"params": {"p1": 2, "p2": "a"}, "target": 3},
315+
]
316+
assert len(space.res()) == 4
317+
assert space.res() == expected_res
318+
319+
320+
def test_res_categorical_with_constraints() -> None:
321+
PBOUNDS = {"p1": (0, 10), "p2": ["a", "b", "c"]}
322+
323+
def _f(p1: float, p2: str) -> float:
324+
return p1 + len(p2)
325+
326+
space = TargetSpace(_f, PBOUNDS, constraint=NonlinearConstraint(lambda p1, p2: p1 - 2, 0, 5))
327+
328+
assert space.res() == []
329+
space.probe(params={"p1": 1, "p2": "a"})
330+
space.probe(params={"p1": 5, "p2": "b"})
331+
space.probe(params={"p1": 2, "p2": "c"})
332+
space.probe(params={"p1": 2, "p2": "a"})
333+
334+
expected_res = [
335+
{"params": {"p1": 1, "p2": "a"}, "target": 2, "allowed": False, "constraint": -1},
336+
{"params": {"p1": 5, "p2": "b"}, "target": 6, "allowed": True, "constraint": 3},
337+
{"params": {"p1": 2, "p2": "c"}, "target": 3, "allowed": True, "constraint": 0},
338+
{"params": {"p1": 2, "p2": "a"}, "target": 3, "allowed": True, "constraint": 0},
339+
]
340+
assert len(space.res()) == 4
341+
assert space.res() == expected_res
342+
343+
276344
def test_set_bounds():
277345
pbounds = {"p1": (0, 1), "p3": (0, 3), "p2": (0, 2), "p4": (0, 4)}
278346
space = TargetSpace(target_func, pbounds)

0 commit comments

Comments
 (0)