Skip to content

Commit b86bae9

Browse files
feat: Allow tuner warm start and custom constraint objective value (#235)
Exposes `OptunaSearch`'s `points_to_evaluate` for warm-starting the tuner, extends `ConstraintError` to allow specifying a custom fallback objective value when a constraint is breached, and updates the documentation accordingly. # Summary Three related tuner enhancements: 1. **Warm start**: Pass initial parameter guesses to `OptunaSearch` via `points_to_evaluate` — useful in large or heavily constrained search spaces where finding a feasible solution from scratch is slow. 2. **Custom constraint penalty**: `ConstraintError` now accepts an `objective_value` keyword arg, giving users control over the fallback metric value on constraint violation. 3. **Documentation**: The "Tuning a process" tutorial tip box on `ConstraintError` has been updated to mention the custom objective value feature. # Changes - **`OptunaSpec`** (`plugboard-schemas/plugboard_schemas/tune.py`): Added `points_to_evaluate: list[dict[str, Any]] | None = None`, passed through to `OptunaSearch` automatically. - **`ConstraintError`** (`plugboard/exceptions/__init__.py`): Added `objective_value: float | None = None` keyword argument — fully backward compatible. - **`Tuner._build_objective`** (`plugboard/tune/tune.py`): When catching `ConstraintError`, uses `objective_value` from the first exception if set; falls back to ±∞ otherwise. - **`docs/examples/tutorials/tuning-a-process.md`**: Expanded the `ConstraintError` tip box to document the `objective_value` argument with a code example. ```python # Warm start tuner = Tuner( ..., algorithm=OptunaSpec( points_to_evaluate=[{"component.model.arg.lr": 0.01, "component.model.arg.depth": 4}] ), ) # Custom constraint penalty instead of -inf class MyComponent(Component): async def step(self) -> None: if self.value > threshold: raise ConstraintError("Exceeded threshold", objective_value=0.0) ... ``` <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > > ---- > > *This section details on the original issue you should resolve* > > <issue_title>feat: Allow tuner to accept an initial guess to give it a 'warm start'</issue_title> > <issue_description>### Summary > > The `OptunaSearch` algo supports a `points_to_evaluate` argument, which can be populated with an initial guess. This is particularly useful when exploring large and heavily constrained search spaces, where it otherwise takes a long time for the tuner to find a feasible solution. > > Feature request to to expose this argument, and update docs where necessary. > > ### Example > > _No response_ > > ### Alternatives > > _No response_</issue_description> > > <agent_instructions>Also implement a mechanism to set the objective to a value other than plus or minus infinity when a constraint is breached. This could be done by extending the ConstraintError class, allowing the user to pass in a value that they would like to give the objective. Make sure to preserve backwards compatibility with the current t behaviour. </agent_instructions> > > ## Comments on the Issue (you are @copilot in this section) > > <comments> > </comments> > </details> <!-- START COPILOT CODING AGENT SUFFIX --> - Fixes #231 <!-- START COPILOT CODING AGENT TIPS --> --- ⌨️ Start Copilot coding agent tasks without leaving your editor — available in [VS Code](https://gh.io/cca-vs-code-docs), [Visual Studio](https://gh.io/cca-visual-studio-docs), [JetBrains IDEs](https://gh.io/cca-jetbrains-docs) and [Eclipse](https://gh.io/cca-eclipse-docs). --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> Co-authored-by: Toby Coleman <toby@tobycoleman.com>
1 parent 8efdcf8 commit b86bae9

6 files changed

Lines changed: 145 additions & 6 deletions

File tree

docs/examples/tutorials/tuning-a-process.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ Running this code will execute an optimisation job and print out information on
6161
!!! tip
6262
You can impose arbitary constraints on variables within a `Process`. In your `step` method you can raise a [`ConstraintError`][plugboard.exceptions.ConstraintError] to indicate to the `Tuner` that a constraint has been breached. This will cause the trial to be stopped, and the optimisation will continue trying to find parameters that don't cause the constraint violation.
6363

64+
By default, the objective is set to ±∞ (depending on the optimisation direction) when a constraint is breached. You can override this by passing an `objective_value` to `ConstraintError`:
65+
```python
66+
raise ConstraintError("Value too high", objective_value=0.0)
67+
```
68+
This is useful when you want violated trials to receive a specific penalty value rather than infinity.
69+
6470
!!! tip
6571
You can optimise over process parameters if you have them in your model. Set `object_type="process"` and `field_type="parameter"` when specifying your tunable parameter.
6672

plugboard-schemas/plugboard_schemas/tune.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ class OptunaSpec(PlugboardBaseModel):
2323
the `ProcessSpec` will be passed to it.
2424
study_name: Optional; The name of the study.
2525
storage: Optional; The storage URI to save the optimisation results to.
26+
points_to_evaluate: Optional; A list of initial parameter configurations to evaluate
27+
first. Each entry is a dict mapping parameter full names to values. Useful for
28+
providing a warm start when exploring large or heavily constrained search spaces.
2629
"""
2730

2831
type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch"
2932
space: str | None = None
3033
study_name: str | None = None
3134
storage: str | None = None
35+
points_to_evaluate: list[dict[str, _t.Any]] | None = None
3236

3337

3438
class BaseFieldSpec(PlugboardBaseModel, ABC):

plugboard/exceptions/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,18 @@ class ValidationError(Exception):
9898

9999

100100
class ConstraintError(Exception):
101-
"""Raised when a constraint is violated."""
102-
103-
pass
101+
"""Raised when a constraint is violated.
102+
103+
Args:
104+
*args: Standard exception arguments.
105+
objective_value: Optional; A custom value to assign to the objective when this constraint
106+
is violated. If not provided, the tuner will assign plus or minus infinity depending
107+
on the optimisation direction.
108+
"""
109+
110+
def __init__(self, *args: object, objective_value: float | None = None) -> None:
111+
super().__init__(*args)
112+
self.objective_value = objective_value
104113

105114

106115
class ProcessStatusError(Exception):

plugboard/tune/tune.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,14 +289,19 @@ def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover
289289
result = {
290290
obj.full_name: self._get_objective(process, obj) for obj in self._objective
291291
}
292-
except* ConstraintError as e:
292+
except* ConstraintError as eg:
293293
modes = self._mode if isinstance(self._mode, list) else [self._mode]
294294
self._logger.warning(
295295
"Constraint violated during optimisation, stopping early",
296-
constraint_error=str(e),
296+
constraint_error=str(eg),
297297
)
298+
first_exc = _t.cast(ConstraintError, eg.exceptions[0]) if eg.exceptions else None
298299
result = {
299-
obj.full_name: math.inf if mode == "min" else -math.inf
300+
obj.full_name: (
301+
first_exc.objective_value
302+
if first_exc is not None and first_exc.objective_value is not None
303+
else (math.inf if mode == "min" else -math.inf)
304+
)
300305
for obj, mode in zip(self._objective, modes)
301306
}
302307

tests/integration/test_tuner.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ async def step(self) -> None:
3434
await super().step()
3535

3636

37+
class ConstrainedBWithObjectiveValue(B):
38+
"""Component with a constraint that provides a custom objective value."""
39+
40+
async def step(self) -> None:
41+
"""Override step to apply a constraint with a custom objective value."""
42+
if self.in_1 > 10:
43+
raise ConstraintError("Input must not be greater than 10", objective_value=0.0)
44+
await super().step()
45+
46+
3747
class DynamicListComponent(ComponentTestHelper):
3848
"""Component with a dynamic list parameter for tuning."""
3949

@@ -294,6 +304,55 @@ async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None:
294304
)
295305

296306

307+
@pytest.mark.tuner
308+
@pytest.mark.asyncio
309+
async def test_tune_with_constraint_objective_value(config: dict, ray_ctx: None) -> None:
310+
"""Tests that a ConstraintError with objective_value uses that value instead of infinity."""
311+
spec = ConfigSpec.model_validate(config)
312+
process_spec = spec.plugboard.process
313+
# Replace component B with a version that provides a custom objective value on constraint
314+
process_spec.args.components[
315+
1
316+
].type = "tests.integration.test_tuner.ConstrainedBWithObjectiveValue"
317+
tuner = Tuner(
318+
objective=ObjectiveSpec(
319+
object_type="component",
320+
object_name="c",
321+
field_type="field",
322+
field_name="in_1",
323+
),
324+
parameters=[
325+
IntParameterSpec(
326+
object_type="component",
327+
object_name="a",
328+
field_type="arg",
329+
field_name="iters",
330+
lower=5,
331+
upper=15,
332+
)
333+
],
334+
num_samples=12,
335+
mode="max",
336+
max_concurrent=2,
337+
algorithm=OptunaSpec(),
338+
)
339+
best_result = tuner.run(
340+
spec=process_spec,
341+
)
342+
result = tuner.result_grid
343+
# There must be no failed trials
344+
assert not any(t.error for t in result)
345+
# Optimum must be at or below 10 (constraint threshold)
346+
assert best_result.metrics["component.c.field.in_1"] <= 10
347+
# When constraint is violated the custom objective_value (0.0) must be used, not -inf
348+
# The constraint raises when in_1 > 10; in_1 = iters - 1, so iters > 11 violates it
349+
assert all(
350+
t.metrics["component.c.field.in_1"] == 0.0
351+
for t in result
352+
if t.config["component.a.arg.iters"] > 11
353+
)
354+
355+
297356
@pytest.mark.tuner
298357
@pytest.mark.asyncio
299358
@pytest.mark.parametrize("space_func", [custom_space, custom_space_with_process_spec])

tests/unit/test_tuner.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
from ray.tune.search.optuna import OptunaSearch
1010

11+
from plugboard.exceptions import ConstraintError
1112
from plugboard.schemas import (
1213
CategoricalParameterSpec,
1314
ConfigSpec,
@@ -131,3 +132,58 @@ def test_optuna_storage_uri_conversion(temp_dir: str) -> None:
131132
tuner.run(spec=MagicMock())
132133
passed_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg
133134
assert isinstance(passed_alg, OptunaSearch)
135+
136+
137+
def test_optuna_points_to_evaluate(config: dict) -> None:
138+
"""Test that points_to_evaluate is passed through to the OptunaSearch algorithm."""
139+
spec = ConfigSpec.model_validate(config)
140+
process_spec = spec.plugboard.process
141+
points = [{"component.a.arg.x": 7, "component.a.arg.y": 0.3}]
142+
tuner = Tuner(
143+
objective=ObjectiveSpec(
144+
object_type="component",
145+
object_name="c",
146+
field_type="field",
147+
field_name="in_1",
148+
),
149+
parameters=[
150+
IntParameterSpec(
151+
object_type="component",
152+
object_name="a",
153+
field_type="arg",
154+
field_name="x",
155+
lower=6,
156+
upper=8,
157+
),
158+
FloatParameterSpec(
159+
object_type="component",
160+
object_name="a",
161+
field_type="arg",
162+
field_name="y",
163+
lower=0.1,
164+
upper=0.5,
165+
),
166+
],
167+
num_samples=3,
168+
mode="max",
169+
algorithm=OptunaSpec(points_to_evaluate=points),
170+
)
171+
with patch("ray.tune.Tuner") as mock_tuner_cls:
172+
tuner.run(spec=process_spec)
173+
search_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg
174+
# The underlying OptunaSearch must have received the points_to_evaluate
175+
assert isinstance(search_alg, OptunaSearch)
176+
assert search_alg._points_to_evaluate == points
177+
178+
179+
def test_constraint_error_objective_value() -> None:
180+
"""Test that ConstraintError stores an optional objective_value."""
181+
# Default (no objective_value): backward compatible usage
182+
err = ConstraintError("constraint violated")
183+
assert err.objective_value is None
184+
assert str(err) == "constraint violated"
185+
186+
# With objective_value keyword argument
187+
err_with_value = ConstraintError("constraint violated", objective_value=5.0)
188+
assert err_with_value.objective_value == 5.0
189+
assert str(err_with_value) == "constraint violated"

0 commit comments

Comments
 (0)