77import json
88import math
99import time
10- import warnings
1110from collections .abc import Callable
1211from typing import TYPE_CHECKING , Any , Literal
1312
@@ -682,13 +681,27 @@ def _init_search(
682681 # Extract original function for single-point use during init
683682 objective_function = objective_function ._gfo_original_func
684683
684+ # The objective may return (score, metrics) tuples, but the
685+ # distributed batch interface only passes scores between
686+ # workers and the search loop. Normalize here so all
687+ # distributed paths (sync batch, true-async, batch-async)
688+ # receive plain floats.
689+ _raw_distributed = self ._original_func
690+
691+ def _normalize_return (params ):
692+ out = _raw_distributed (params )
693+ if isinstance (out , tuple ):
694+ return out [0 ]
695+ return out
696+
697+ self ._original_func = _normalize_return
698+
685699 if catch :
686- warnings .warn (
687- "catch parameter is not yet supported with distributed "
688- "evaluation. Ignoring catch for this search." ,
689- stacklevel = 3 ,
690- )
691- catch = None
700+ self ._original_func = wrap_with_catch (self ._original_func , catch )
701+
702+ # Rebuild the wrapper so the sync batch path also uses the
703+ # normalized (and optionally catch-wrapped) function
704+ self ._distributed_func = self ._backend .distribute (self ._original_func )
692705 else :
693706 self ._batch_size = None
694707 self ._backend = None
@@ -697,7 +710,15 @@ def _init_search(
697710 objective_function = wrap_with_catch (objective_function , catch )
698711
699712 if getattr (self , "optimum" , "maximum" ) == "minimum" :
700- self .objective_function = lambda pos : - objective_function (pos )
713+ _obj = objective_function
714+
715+ def _negate (params ):
716+ out = _obj (params )
717+ if isinstance (out , tuple ):
718+ return - out [0 ], out [1 ]
719+ return - out
720+
721+ self .objective_function = _negate
701722 else :
702723 self .objective_function = objective_function
703724 self .n_iter = n_iter
0 commit comments