Skip to content

Commit fc8e738

Browse files
committed
fix error when obj-func returns meta-data
1 parent 3a21c98 commit fc8e738

2 files changed

Lines changed: 31 additions & 9 deletions

File tree

src/gradient_free_optimizers/search.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import json
88
import math
99
import time
10-
import warnings
1110
from collections.abc import Callable
1211
from typing import TYPE_CHECKING, Any, Literal
1312

@@ -682,13 +681,27 @@ def _init_search(
682681
# Extract original function for single-point use during init
683682
objective_function = objective_function._gfo_original_func
684683

684+
# The objective may return (score, metrics) tuples, but the
685+
# distributed batch interface only passes scores between
686+
# workers and the search loop. Normalize here so all
687+
# distributed paths (sync batch, true-async, batch-async)
688+
# receive plain floats.
689+
_raw_distributed = self._original_func
690+
691+
def _normalize_return(params):
692+
out = _raw_distributed(params)
693+
if isinstance(out, tuple):
694+
return out[0]
695+
return out
696+
697+
self._original_func = _normalize_return
698+
685699
if catch:
686-
warnings.warn(
687-
"catch parameter is not yet supported with distributed "
688-
"evaluation. Ignoring catch for this search.",
689-
stacklevel=3,
690-
)
691-
catch = None
700+
self._original_func = wrap_with_catch(self._original_func, catch)
701+
702+
# Rebuild the wrapper so the sync batch path also uses the
703+
# normalized (and optionally catch-wrapped) function
704+
self._distributed_func = self._backend.distribute(self._original_func)
692705
else:
693706
self._batch_size = None
694707
self._backend = None
@@ -697,7 +710,15 @@ def _init_search(
697710
objective_function = wrap_with_catch(objective_function, catch)
698711

699712
if getattr(self, "optimum", "maximum") == "minimum":
700-
self.objective_function = lambda pos: -objective_function(pos)
713+
_obj = objective_function
714+
715+
def _negate(params):
716+
out = _obj(params)
717+
if isinstance(out, tuple):
718+
return -out[0], out[1]
719+
return -out
720+
721+
self.objective_function = _negate
701722
else:
702723
self.objective_function = objective_function
703724
self.n_iter = n_iter

src/gradient_free_optimizers/storage/_sqlite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def __init__(self, path: str):
7171
self._conn.commit()
7272

7373
def _key_to_str(self, key: tuple) -> str:
74-
return json.dumps(key)
74+
# Convert numpy integers to Python ints for JSON serialization
75+
return json.dumps([int(k) for k in key])
7576

7677
def _str_to_key(self, s: str) -> tuple:
7778
return tuple(json.loads(s))

0 commit comments

Comments
 (0)