Skip to content

Commit 1974038

Browse files
committed
Use array_to_params to translate categoricals back for max()
1 parent 25084fe commit 1974038

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

bayes_opt/target_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def max(self) -> dict[str, Any] | None:
643643
params = self.params[self.mask]
644644
target_max_idx = np.argmax(target)
645645

646-
res = {"target": target_max, "params": dict(zip(self.keys, params[target_max_idx]))}
646+
res = {"target": target_max, "params": self.array_to_params(params[target_max_idx])}
647647

648648
if self._constraint is not None:
649649
constraint_values = self.constraint_values[self.mask]

tests/test_target_space.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ def test_max_with_constraint_identical_target_value():
253253
assert space.max() == {"params": {"p1": 2, "p2": 3}, "target": 5, "constraint": -1}
254254

255255

256+
def test_max_categorical() -> None:
257+
PBOUNDS = {
258+
"first_float": (0.0, 1.0),
259+
"categorical_value": ("a", "b", "c", "d"),
260+
"second_float": (0.0, 1.0),
261+
}
262+
263+
def _f(first_float: float, categorical_value: str, second_float: float) -> float:
264+
return second_float if categorical_value == "c" else first_float
265+
266+
space = TargetSpace(_f, PBOUNDS)
267+
space.probe(params={"first_float": 0.1, "categorical_value": "a", "second_float": 0.1})
268+
space.probe(params={"first_float": 0.1, "categorical_value": "b", "second_float": 0.9})
269+
space.probe(params={"first_float": 0.1, "categorical_value": "c", "second_float": 0.8})
270+
space.probe(params={"first_float": 0.1, "categorical_value": "d", "second_float": 0.9})
271+
272+
expected = {"first_float": 0.1, "categorical_value": "c", "second_float": 0.8}
273+
assert space.max()["params"] == expected
274+
275+
256276
def test_res():
257277
PBOUNDS = {"p1": (0, 10), "p2": (1, 100)}
258278
space = TargetSpace(target_func, PBOUNDS)

0 commit comments

Comments
 (0)