Skip to content

Commit d2695f7

Browse files
committed
unify unpacking of obj. func. result into one place
1 parent 99a2d6e commit d2695f7

File tree

3 files changed

+70
-18
lines changed

3 files changed

+70
-18
lines changed

src/gradient_free_optimizers/_objective_adapter.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from ._result import Result
1+
from ._result import Result, unpack_objective_result
22

33

44
class ObjectiveAdapter:
5-
"""Maps *pos* → (score, metrics, params)."""
5+
"""Maps *pos* → (Result, params)."""
66

77
def __init__(self, conv, objective):
88
self._conv = conv
@@ -13,11 +13,7 @@ def _call_objective(self, pos):
1313
params = self._conv.value2para(self._conv.position2value(pos))
1414
out = self._objective(params)
1515

16-
if isinstance(out, tuple):
17-
score, metrics = out
18-
else:
19-
score, metrics = float(out), {}
20-
16+
score, metrics = unpack_objective_result(out)
2117
result = Result(score, metrics)
2218

2319
return result, params
Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,67 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22

33

44
@dataclass
55
class Result:
6+
"""Internal result container used throughout the search pipeline."""
7+
68
score: float
79
metrics: dict
10+
11+
12+
@dataclass
13+
class ObjectiveResult:
14+
"""Explicit return type for objective functions.
15+
16+
Avoids the ambiguity of bare tuples, which becomes critical when the
17+
return value needs to carry structured data alongside the score (e.g.,
18+
custom metrics). A plain ``(float, dict)`` tuple works today but
19+
collides with future multi-objective returns like ``(float, float)``.
20+
21+
Parameters
22+
----------
23+
score : float
24+
The objective function score.
25+
metrics : dict, optional
26+
Custom metrics to record alongside the score.
27+
28+
Examples
29+
--------
30+
>>> def objective(params):
31+
... loss = params["x"] ** 2
32+
... return ObjectiveResult(score=-loss, metrics={"raw_loss": loss})
33+
"""
34+
35+
score: float
36+
metrics: dict = field(default_factory=dict)
37+
38+
39+
def unpack_objective_result(raw) -> tuple[float, dict]:
40+
"""Extract score and metrics from an objective function's return value.
41+
42+
Single entry point for parsing objective output. All code paths that
43+
receive raw objective returns (serial adapter, distributed unpacking,
44+
minimization wrapper) should call this instead of doing their own
45+
isinstance checks.
46+
47+
Supported return conventions (checked in this order):
48+
49+
1. ``ObjectiveResult`` instance (preferred, unambiguous)
50+
2. ``(float, dict)`` tuple (legacy convention)
51+
3. ``float`` (score only, no metrics)
52+
53+
Parameters
54+
----------
55+
raw : float or tuple or ObjectiveResult
56+
Raw return value from an objective function.
57+
58+
Returns
59+
-------
60+
tuple[float, dict]
61+
The (score, metrics) pair.
62+
"""
63+
if isinstance(raw, ObjectiveResult):
64+
return raw.score, raw.metrics
65+
if isinstance(raw, tuple):
66+
return raw[0], raw[1]
67+
return float(raw), {}

src/gradient_free_optimizers/search.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ._objective_adapter import ObjectiveAdapter
1818
from ._print_info import print_summary
1919
from ._progress_bar import ProgressBarLVL0, ProgressBarLVL1
20-
from ._result import Result
20+
from ._result import Result, unpack_objective_result
2121
from ._results_manager import ResultsManager
2222
from ._search_statistics import SearchStatistics
2323
from ._stopping_conditions import OptimizationStopper
@@ -190,13 +190,10 @@ def _iteration_batch(self, batch_size):
190190
def _unpack_result(raw):
191191
"""Separate a worker return value into score and metrics.
192192
193-
Objective functions may return a plain float or a (float, dict) tuple.
194-
The backends pass through whatever the function returns, so we unpack
195-
here at the boundary between worker results and the search loop.
193+
Delegates to the centralized unpack_objective_result() which handles
194+
ObjectiveResult, (float, dict) tuples, and plain floats.
196195
"""
197-
if isinstance(raw, tuple):
198-
return raw[0], raw[1]
199-
return raw, {}
196+
return unpack_objective_result(raw)
200197

201198
def _track_evaluation(self, pos, score, eval_time=0, iter_time=0, metrics=None):
202199
"""Record a single evaluation result across all tracking systems."""
@@ -737,9 +734,8 @@ def _init_search(
737734

738735
def _negate(params):
739736
out = _obj(params)
740-
if isinstance(out, tuple):
741-
return -out[0], out[1]
742-
return -out
737+
score, metrics = unpack_objective_result(out)
738+
return (-score, metrics)
743739

744740
self.objective_function = _negate
745741
else:

0 commit comments

Comments
 (0)