Skip to content

Commit 0c0b4d8

Browse files
committed
really fix meta-data this time
1 parent fc8e738 commit 0c0b4d8

2 files changed

Lines changed: 72 additions & 49 deletions

File tree

src/gradient_free_optimizers/distributed/_base.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -65,28 +65,31 @@ def __init__(self, n_workers: int):
6565
self.n_workers = n_workers
6666

6767
@abstractmethod
68-
def _distribute(self, func, params_batch: list[dict]) -> list[float]:
68+
def _distribute(self, func, params_batch: list[dict]) -> list:
6969
"""Evaluate func(params) for each params dict in the batch.
7070
7171
This is the only method subclasses need to implement for
7272
synchronous (batch) evaluation. It receives the original
7373
(unwrapped) objective function and a list of parameter
74-
dictionaries, and must return a list of scores in the same order.
74+
dictionaries, and must return a list of results in the same order.
7575
7676
The function follows GFO's convention: ``func(params_dict)`` where
7777
params_dict is a single dictionary, not keyword arguments.
7878
7979
Parameters
8080
----------
8181
func : callable
82-
The original objective function with signature f(dict) -> float.
82+
The original objective function with signature
83+
f(dict) -> float or f(dict) -> (float, dict).
8384
params_batch : list[dict]
8485
Parameter dictionaries to evaluate.
8586
8687
Returns
8788
-------
88-
list[float]
89-
Scores in the same order as params_batch.
89+
list[float | tuple[float, dict]]
90+
Results in the same order as params_batch. Each element is
91+
either a plain score or a (score, metrics) tuple, matching
92+
whatever the objective function returns.
9093
"""
9194
...
9295

@@ -119,7 +122,7 @@ def _wait_any(self, futures) -> tuple:
119122
120123
Only required for async backends (``_is_async = True``).
121124
Blocks until at least one future from the collection is ready,
122-
then returns both the future object and its result value.
125+
then returns both the future object and its raw result.
123126
124127
Parameters
125128
----------
@@ -128,8 +131,9 @@ def _wait_any(self, futures) -> tuple:
128131
129132
Returns
130133
-------
131-
tuple of (future, float)
132-
The completed future object and its result (score).
134+
tuple of (future, float | tuple[float, dict])
135+
The completed future object and the objective function's
136+
return value (a plain score or a (score, metrics) tuple).
133137
The future is returned so the caller can identify which
134138
submission completed (e.g., to look up the associated position).
135139
"""
@@ -142,18 +146,20 @@ def distribute(self, func):
142146
"""Decorator that wraps a single-point objective for batch evaluation.
143147
144148
The decorated function accepts a list of parameter dicts and returns
145-
a list of scores. It also carries metadata attributes that the
149+
a list of results. It also carries metadata attributes that the
146150
optimizer's search loop uses to detect and configure batch mode.
147151
148152
Parameters
149153
----------
150154
func : callable
151-
Objective function with signature f(params_dict) -> float.
155+
Objective function with signature
156+
f(params_dict) -> float or f(params_dict) -> (float, dict).
152157
153158
Returns
154159
-------
155160
callable
156-
Wrapped function with signature f(list[dict]) -> list[float].
161+
Wrapped function with signature
162+
f(list[dict]) -> list[float | tuple[float, dict]].
157163
"""
158164

159165
def wrapper(params_batch):

src/gradient_free_optimizers/search.py

Lines changed: 55 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -136,13 +136,15 @@ def _iteration_batch(self, batch_size):
136136

137137
# Check storage for cached results (preserving original order)
138138
scores = [None] * n_positions
139+
metrics_list = [{}] * n_positions
139140
uncached_indices = []
140141

141142
if self._storage is not None:
142143
for i, pos in enumerate(positions):
143144
cached = self._storage.get(tuple(pos))
144145
if cached is not None:
145146
scores[i] = cached.score
147+
metrics_list[i] = cached.metrics
146148
else:
147149
uncached_indices.append(i)
148150
else:
@@ -156,16 +158,17 @@ def _iteration_batch(self, batch_size):
156158
params_batch.append(self.conv.value2para(value))
157159

158160
t_start = time.time()
159-
new_scores = self._distributed_func(params_batch)
161+
raw_results = self._distributed_func(params_batch)
160162
eval_time = time.time() - t_start
161163

162-
if self.optimum == "minimum":
163-
new_scores = [-s for s in new_scores]
164-
165-
for idx, score in zip(uncached_indices, new_scores):
164+
for idx, raw in zip(uncached_indices, raw_results):
165+
score, metrics = self._unpack_result(raw)
166+
if self.optimum == "minimum":
167+
score = -score
166168
scores[idx] = score
169+
metrics_list[idx] = metrics
167170
if self._storage is not None:
168-
self._storage.put(tuple(positions[idx]), Result(score, {}))
171+
self._storage.put(tuple(positions[idx]), Result(score, metrics))
169172

170173
per_eval_time = eval_time / len(uncached_indices)
171174
else:
@@ -179,25 +182,50 @@ def _iteration_batch(self, batch_size):
179182
uncached_set = set(uncached_indices)
180183
for i, (pos, score) in enumerate(zip(positions, scores)):
181184
et = per_eval_time if i in uncached_set else 0
182-
self._track_evaluation(pos, score, et, per_iter_time)
185+
self._track_evaluation(
186+
pos, score, et, per_iter_time, metrics=metrics_list[i]
187+
)
188+
189+
@staticmethod
190+
def _unpack_result(raw):
191+
"""Separate a worker return value into score and metrics.
192+
193+
Objective functions may return a plain float or a (float, dict) tuple.
194+
The backends pass through whatever the function returns, so we unpack
195+
here at the boundary between worker results and the search loop.
196+
"""
197+
if isinstance(raw, tuple):
198+
return raw[0], raw[1]
199+
return raw, {}
183200

184-
def _track_evaluation(self, pos, score, eval_time=0, iter_time=0):
201+
def _track_evaluation(self, pos, score, eval_time=0, iter_time=0, metrics=None):
185202
"""Record a single evaluation result across all tracking systems."""
203+
if metrics is None:
204+
metrics = {}
186205
self.eval_times.append(eval_time)
187206
self.iter_times.append(iter_time)
188-
self.results_manager.add(Result(score, {}), pos)
207+
self.results_manager.add(Result(score, metrics), pos)
189208
self.pos_l.append(pos)
190209
self.score_l.append(score)
191210
self.p_bar.update(score, pos, self.nth_iter)
192-
self._tracker.track(pos, score, {}, is_init=False)
211+
self._last_metrics = metrics
212+
self._tracker.track(pos, score, metrics, is_init=False)
213+
# Feed each evaluation to the stopper so early_stopping counts
214+
# individual evaluations, not batches
215+
self.stopper.update(score, self.p_bar.score_best, self._iter)
193216
self.n_iter_total += 1
194217
self.n_iter_search += 1
195218
self._iter += 1
196219

197220
def _check_stop(self, n_evaluated):
198221
"""Check stopping conditions and run callbacks. Returns True if should stop."""
199-
current_score = self.score_l[-1] if self.score_l else -math.inf
200-
self.stopper.update(current_score, self.p_bar.score_best, n_evaluated - 1)
222+
# In serial mode, the stopper needs per-evaluation updates here.
223+
# In distributed mode, _track_evaluation already updated the stopper
224+
# for each evaluation in the batch, giving correct granularity for
225+
# early_stopping counters.
226+
if not self._is_distributed:
227+
current_score = self.score_l[-1] if self.score_l else -math.inf
228+
self.stopper.update(current_score, self.p_bar.score_best, n_evaluated - 1)
201229

202230
if self.stopper.should_stop():
203231
if "debug_stop" in self.verbosity:
@@ -248,7 +276,7 @@ def _submit_one():
248276
cached = self._storage.get(tuple(pos))
249277
if cached is not None:
250278
self._evaluate_batch([pos], [cached.score])
251-
self._track_evaluation(pos, cached.score)
279+
self._track_evaluation(pos, cached.score, metrics=cached.metrics)
252280
n_evaluated += 1
253281
return 1
254282

@@ -267,19 +295,20 @@ def _submit_one():
267295
# Process results as they arrive
268296
while pending and n_evaluated < n_iter:
269297
t_iter = time.time()
270-
completed, score = backend._wait_any(list(pending.keys()))
298+
completed, raw = backend._wait_any(list(pending.keys()))
271299
pos = pending.pop(completed)
272300

301+
score, metrics = self._unpack_result(raw)
273302
if self.optimum == "minimum":
274303
score = -score
275304

276305
if self._storage is not None:
277-
self._storage.put(tuple(pos), Result(score, {}))
306+
self._storage.put(tuple(pos), Result(score, metrics))
278307

279308
self.nth_iter = n_evaluated
280309
self._evaluate_batch([pos], [score])
281310
iter_time = time.time() - t_iter
282-
self._track_evaluation(pos, score, iter_time, iter_time)
311+
self._track_evaluation(pos, score, iter_time, iter_time, metrics=metrics)
283312
n_evaluated += 1
284313

285314
if self._check_stop(n_evaluated):
@@ -315,13 +344,15 @@ def _run_batch_async(self, n_iter, nth_trial):
315344

316345
# Check storage cache (same logic as sync _iteration_batch)
317346
scores = [None] * n_positions
347+
metrics_list = [{}] * n_positions
318348
uncached_indices = []
319349

320350
if self._storage is not None:
321351
for i, pos in enumerate(positions):
322352
cached = self._storage.get(tuple(pos))
323353
if cached is not None:
324354
scores[i] = cached.score
355+
metrics_list[i] = cached.metrics
325356
else:
326357
uncached_indices.append(i)
327358
else:
@@ -339,15 +370,17 @@ def _run_batch_async(self, n_iter, nth_trial):
339370
# Collect all results (async within batch)
340371
t_start = time.time()
341372
while futures:
342-
completed, score = backend._wait_any(list(futures.keys()))
373+
completed, raw = backend._wait_any(list(futures.keys()))
343374
idx = futures.pop(completed)
344375

376+
score, metrics = self._unpack_result(raw)
345377
if self.optimum == "minimum":
346378
score = -score
347379

348380
scores[idx] = score
381+
metrics_list[idx] = metrics
349382
if self._storage is not None:
350-
self._storage.put(tuple(positions[idx]), Result(score, {}))
383+
self._storage.put(tuple(positions[idx]), Result(score, metrics))
351384

352385
per_eval_time = (time.time() - t_start) / len(uncached_indices)
353386
else:
@@ -361,7 +394,9 @@ def _run_batch_async(self, n_iter, nth_trial):
361394
uncached_set = set(uncached_indices)
362395
for i, (pos, score) in enumerate(zip(positions, scores)):
363396
et = per_eval_time if i in uncached_set else 0
364-
self._track_evaluation(pos, score, et, per_iter_time)
397+
self._track_evaluation(
398+
pos, score, et, per_iter_time, metrics=metrics_list[i]
399+
)
365400

366401
n_evaluated += n_positions
367402

@@ -681,27 +716,9 @@ def _init_search(
681716
# Extract original function for single-point use during init
682717
objective_function = objective_function._gfo_original_func
683718

684-
# The objective may return (score, metrics) tuples, but the
685-
# distributed batch interface only passes scores between
686-
# workers and the search loop. Normalize here so all
687-
# distributed paths (sync batch, true-async, batch-async)
688-
# receive plain floats.
689-
_raw_distributed = self._original_func
690-
691-
def _normalize_return(params):
692-
out = _raw_distributed(params)
693-
if isinstance(out, tuple):
694-
return out[0]
695-
return out
696-
697-
self._original_func = _normalize_return
698-
699719
if catch:
700720
self._original_func = wrap_with_catch(self._original_func, catch)
701-
702-
# Rebuild the wrapper so the sync batch path also uses the
703-
# normalized (and optionally catch-wrapped) function
704-
self._distributed_func = self._backend.distribute(self._original_func)
721+
self._distributed_func = self._backend.distribute(self._original_func)
705722
else:
706723
self._batch_size = None
707724
self._backend = None

0 commit comments

Comments (0)