@@ -136,13 +136,15 @@ def _iteration_batch(self, batch_size):
 
         # Check storage for cached results (preserving original order)
         scores = [None] * n_positions
+        metrics_list = [{}] * n_positions
         uncached_indices = []
 
         if self._storage is not None:
             for i, pos in enumerate(positions):
                 cached = self._storage.get(tuple(pos))
                 if cached is not None:
                     scores[i] = cached.score
+                    metrics_list[i] = cached.metrics
                 else:
                     uncached_indices.append(i)
         else:
@@ -156,16 +158,17 @@ def _iteration_batch(self, batch_size):
                 params_batch.append(self.conv.value2para(value))
 
             t_start = time.time()
-            new_scores = self._distributed_func(params_batch)
+            raw_results = self._distributed_func(params_batch)
             eval_time = time.time() - t_start
 
-            if self.optimum == "minimum":
-                new_scores = [-s for s in new_scores]
-
-            for idx, score in zip(uncached_indices, new_scores):
+            for idx, raw in zip(uncached_indices, raw_results):
+                score, metrics = self._unpack_result(raw)
+                if self.optimum == "minimum":
+                    score = -score
                 scores[idx] = score
+                metrics_list[idx] = metrics
                 if self._storage is not None:
-                    self._storage.put(tuple(positions[idx]), Result(score, {}))
+                    self._storage.put(tuple(positions[idx]), Result(score, metrics))
 
             per_eval_time = eval_time / len(uncached_indices)
         else:
@@ -179,25 +182,50 @@ def _iteration_batch(self, batch_size):
         uncached_set = set(uncached_indices)
         for i, (pos, score) in enumerate(zip(positions, scores)):
             et = per_eval_time if i in uncached_set else 0
-            self._track_evaluation(pos, score, et, per_iter_time)
+            self._track_evaluation(
+                pos, score, et, per_iter_time, metrics=metrics_list[i]
+            )
+
+    @staticmethod
+    def _unpack_result(raw):
+        """Separate a worker return value into score and metrics.
+
+        Objective functions may return a plain float or a (float, dict) tuple.
+        The backends pass through whatever the function returns, so we unpack
+        here at the boundary between worker results and the search loop.
+        """
+        if isinstance(raw, tuple):
+            return raw[0], raw[1]
+        return raw, {}
 
-    def _track_evaluation(self, pos, score, eval_time=0, iter_time=0):
+    def _track_evaluation(self, pos, score, eval_time=0, iter_time=0, metrics=None):
         """Record a single evaluation result across all tracking systems."""
+        if metrics is None:
+            metrics = {}
         self.eval_times.append(eval_time)
         self.iter_times.append(iter_time)
-        self.results_manager.add(Result(score, {}), pos)
+        self.results_manager.add(Result(score, metrics), pos)
         self.pos_l.append(pos)
         self.score_l.append(score)
         self.p_bar.update(score, pos, self.nth_iter)
-        self._tracker.track(pos, score, {}, is_init=False)
+        self._last_metrics = metrics
+        self._tracker.track(pos, score, metrics, is_init=False)
+        # Feed each evaluation to the stopper so early_stopping counts
+        # individual evaluations, not batches
+        self.stopper.update(score, self.p_bar.score_best, self._iter)
        self.n_iter_total += 1
         self.n_iter_search += 1
         self._iter += 1
 
     def _check_stop(self, n_evaluated):
         """Check stopping conditions and run callbacks. Returns True if should stop."""
-        current_score = self.score_l[-1] if self.score_l else -math.inf
-        self.stopper.update(current_score, self.p_bar.score_best, n_evaluated - 1)
+        # In serial mode, the stopper needs per-evaluation updates here.
+        # In distributed mode, _track_evaluation already updated the stopper
+        # for each evaluation in the batch, giving correct granularity for
+        # early_stopping counters.
+        if not self._is_distributed:
+            current_score = self.score_l[-1] if self.score_l else -math.inf
+            self.stopper.update(current_score, self.p_bar.score_best, n_evaluated - 1)
 
         if self.stopper.should_stop():
             if "debug_stop" in self.verbosity:
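
A quick illustration of the contract `_unpack_result` enforces (the objective functions and parameter names below are made up for the example, not code from this repository): a worker may hand back either a bare score or a `(score, metrics)` tuple, and both shapes collapse to the same `(score, dict)` pair at the boundary.

def objective_plain(params):
    # Returns only a score (a float).
    return -(params["x"] ** 2)

def objective_with_metrics(params):
    # Returns a (score, metrics) tuple; the dict travels with the result.
    score = -(params["x"] ** 2)
    return score, {"x_squared": params["x"] ** 2}

def _unpack_result(raw):
    # Same logic as the static method added above.
    if isinstance(raw, tuple):
        return raw[0], raw[1]
    return raw, {}

print(_unpack_result(objective_plain({"x": 2})))         # (-4, {})
print(_unpack_result(objective_with_metrics({"x": 2})))  # (-4, {'x_squared': 4})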
@@ -248,7 +276,7 @@ def _submit_one():
                 cached = self._storage.get(tuple(pos))
                 if cached is not None:
                     self._evaluate_batch([pos], [cached.score])
-                    self._track_evaluation(pos, cached.score)
+                    self._track_evaluation(pos, cached.score, metrics=cached.metrics)
                     n_evaluated += 1
                     return 1
 
@@ -267,19 +295,20 @@ def _submit_one():
         # Process results as they arrive
         while pending and n_evaluated < n_iter:
             t_iter = time.time()
-            completed, score = backend._wait_any(list(pending.keys()))
+            completed, raw = backend._wait_any(list(pending.keys()))
             pos = pending.pop(completed)
 
+            score, metrics = self._unpack_result(raw)
             if self.optimum == "minimum":
                 score = -score
 
             if self._storage is not None:
-                self._storage.put(tuple(pos), Result(score, {}))
+                self._storage.put(tuple(pos), Result(score, metrics))
 
             self.nth_iter = n_evaluated
             self._evaluate_batch([pos], [score])
             iter_time = time.time() - t_iter
-            self._track_evaluation(pos, score, iter_time, iter_time)
+            self._track_evaluation(pos, score, iter_time, iter_time, metrics=metrics)
             n_evaluated += 1
 
             if self._check_stop(n_evaluated):
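
The stopper comments above hinge on update granularity: if the stopper only saw one score per batch, an early_stopping-style "no improvement for N evaluations" counter would advance far too slowly. A toy sketch (the `ToyEarlyStopper` class below is invented for illustration and is not this project's stopper) shows the difference:

class ToyEarlyStopper:
    # Stops after n_iter_no_change updates without a new best score.
    def __init__(self, n_iter_no_change):
        self.n_iter_no_change = n_iter_no_change
        self.best = float("-inf")
        self.stale = 0

    def update(self, score):
        if score > self.best:
            self.best = score
            self.stale = 0
        else:
            self.stale += 1

    def should_stop(self):
        return self.stale >= self.n_iter_no_change

scores = [0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.04, 0.03]  # no improvement after the first

# Per-evaluation updates: 7 stale evaluations trip a threshold of 5.
per_eval = ToyEarlyStopper(n_iter_no_change=5)
for s in scores:
    per_eval.update(s)
print(per_eval.should_stop())  # True

# Per-batch updates (batch size 4, only the last score of each batch seen): 1 stale update.
per_batch = ToyEarlyStopper(n_iter_no_change=5)
for batch_start in range(0, len(scores), 4):
    batch = scores[batch_start:batch_start + 4]
    per_batch.update(batch[-1])
print(per_batch.should_stop())  # False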
@@ -315,13 +344,15 @@ def _run_batch_async(self, n_iter, nth_trial):
 
             # Check storage cache (same logic as sync _iteration_batch)
             scores = [None] * n_positions
+            metrics_list = [{}] * n_positions
             uncached_indices = []
 
             if self._storage is not None:
                 for i, pos in enumerate(positions):
                     cached = self._storage.get(tuple(pos))
                     if cached is not None:
                         scores[i] = cached.score
+                        metrics_list[i] = cached.metrics
                     else:
                         uncached_indices.append(i)
             else:
@@ -339,15 +370,17 @@ def _run_batch_async(self, n_iter, nth_trial):
                 # Collect all results (async within batch)
                 t_start = time.time()
                 while futures:
-                    completed, score = backend._wait_any(list(futures.keys()))
+                    completed, raw = backend._wait_any(list(futures.keys()))
                     idx = futures.pop(completed)
 
+                    score, metrics = self._unpack_result(raw)
                     if self.optimum == "minimum":
                         score = -score
 
                     scores[idx] = score
+                    metrics_list[idx] = metrics
                     if self._storage is not None:
-                        self._storage.put(tuple(positions[idx]), Result(score, {}))
+                        self._storage.put(tuple(positions[idx]), Result(score, metrics))
 
                 per_eval_time = (time.time() - t_start) / len(uncached_indices)
             else:
@@ -361,7 +394,9 @@ def _run_batch_async(self, n_iter, nth_trial):
             uncached_set = set(uncached_indices)
             for i, (pos, score) in enumerate(zip(positions, scores)):
                 et = per_eval_time if i in uncached_set else 0
-                self._track_evaluation(pos, score, et, per_iter_time)
+                self._track_evaluation(
+                    pos, score, et, per_iter_time, metrics=metrics_list[i]
+                )
 
             n_evaluated += n_positions
 
@@ -681,27 +716,9 @@ def _init_search(
             # Extract original function for single-point use during init
             objective_function = objective_function._gfo_original_func
 
-            # The objective may return (score, metrics) tuples, but the
-            # distributed batch interface only passes scores between
-            # workers and the search loop. Normalize here so all
-            # distributed paths (sync batch, true-async, batch-async)
-            # receive plain floats.
-            _raw_distributed = self._original_func
-
-            def _normalize_return(params):
-                out = _raw_distributed(params)
-                if isinstance(out, tuple):
-                    return out[0]
-                return out
-
-            self._original_func = _normalize_return
-
             if catch:
                 self._original_func = wrap_with_catch(self._original_func, catch)
-
-            # Rebuild the wrapper so the sync batch path also uses the
-            # normalized (and optionally catch-wrapped) function
-            self._distributed_func = self._backend.distribute(self._original_func)
+                self._distributed_func = self._backend.distribute(self._original_func)
         else:
             self._batch_size = None
             self._backend = None
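
With the `_normalize_return` wrapper gone, whatever the backend's `distribute` wrapper returns reaches the search loop untouched, and `_unpack_result` splits it there. A minimal sketch of that pass-through, assuming a simple thread-pool backend (the `SketchBackend` class and the objective below are illustrative only, not this project's backend implementation):

from concurrent.futures import ThreadPoolExecutor

class SketchBackend:
    # Stand-in batch backend: distribute() maps the objective over a batch
    # and returns the raw results, whether they are floats or (score, dict) tuples.
    def __init__(self, n_workers=2):
        self._pool = ThreadPoolExecutor(max_workers=n_workers)

    def distribute(self, func):
        def batched(params_batch):
            return list(self._pool.map(func, params_batch))
        return batched

def objective(params):
    score = -(params["x"] ** 2)
    return score, {"x_squared": params["x"] ** 2}  # metrics survive the backend

backend = SketchBackend()
distributed_func = backend.distribute(objective)
print(distributed_func([{"x": 1}, {"x": 2}]))
# [(-1, {'x_squared': 1}), (-4, {'x_squared': 4})]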