Skip to content

Commit 9e384e0

Browse files
saitcakmak and facebook-github-bot
authored and committed
Add objectives: list[Objective] to OptimizationConfig (#5150)
Summary: Part of the Restrict Objective to Single/Scalarized & Simplify OptimizationConfig design (see design doc: https://docs.google.com/document/d/1EGQYmBjiNGtYapXu1RLHEBdA5Yz2c7q17acX3es0yV8/edit). This is Diff 1 of the stack: enables the new `OptimizationConfig(objectives=[...])` construction path without breaking any existing code. Changes: - `OptimizationConfig.__init__` and `clone_with_args` are now keyword-only across `OptimizationConfig`, `MOOC`, and `PreferenceOptimizationConfig`. All positional callers updated. - New kwarg `objectives: list[Objective] | None = None`, mutually exclusive with `objective`, on both `__init__` and `clone_with_args`. - Internally stores `self._objectives: list[Objective]` (both paths). - New `objectives` property returns the list. - `objective` property raises `UnsupportedError` if `len > 1`. - `is_moo_problem` property: True when multiple objectives or legacy multi-objective expression. - `metric_names`, `metric_name_to_signature`, `metric_signatures` aggregate across all objectives + constraints. - `__repr__` always uses `objectives=`. - JSON storage: encoder uses `objectives` key; decoder has backward compat to convert old `objective` key to `objectives` list. - SQA storage: encoder iterates `objectives` to encode each one; decoder collects multiple OBJECTIVE rows and reconstructs `OptimizationConfig(objectives=...)` when `len > 1`. - Validation ensures no duplicate metrics across objectives and no multi-objective expressions in individual list elements. Differential Revision: D99387020
1 parent 3712dd6 commit 9e384e0

14 files changed

Lines changed: 364 additions & 99 deletions

File tree

ax/adapter/tests/test_torch_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ def test_pairwise_preference_generator(self) -> None:
12051205
surrogate=surrogate,
12061206
),
12071207
optimization_config=OptimizationConfig(
1208-
Objective(
1208+
objective=Objective(
12091209
metric=Metric(Keys.PAIRWISE_PREFERENCE_QUERY.value),
12101210
minimize=False,
12111211
)

ax/core/optimization_config.py

Lines changed: 184 additions & 65 deletions
Large diffs are not rendered by default.

ax/core/tests/test_multi_type_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_setting_opt_config(self) -> None:
154154
m3 = BraninMetric("m3", ["x1", "x2"])
155155
self.experiment.add_tracking_metric(m3)
156156
self.experiment.optimization_config = OptimizationConfig(
157-
Objective(metric=m3, minimize=True)
157+
objective=Objective(metric=m3, minimize=True)
158158
)
159159
self.assertDictEqual(
160160
self.experiment._metric_to_trial_type,

ax/core/tests/test_optimization_config.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
ScalarizedOutcomeConstraint,
2222
)
2323
from ax.core.types import ComparisonOp
24-
from ax.exceptions.core import UserInputError
24+
from ax.exceptions.core import UnsupportedError, UserInputError
2525
from ax.utils.common.testutils import TestCase
2626
from pyre_extensions import assert_is_instance
2727

2828

2929
OC_STR = (
3030
"OptimizationConfig("
31-
'objective=Objective(expression="m1"), '
31+
'objectives=[Objective(expression="m1")], '
3232
"outcome_constraints=[OutcomeConstraint(m3 >= -0.25), "
3333
"OutcomeConstraint(m4 <= 0.25), "
3434
"ScalarizedOutcomeConstraint(0.5*m3 + 0.5*m4 >= 0.9975 * baseline)])"
@@ -271,6 +271,111 @@ def test_CloneWithArgs(self) -> None:
271271
)
272272

273273

274+
class OptimizationConfigObjectivesListTest(TestCase):
275+
"""Tests for the new `OptimizationConfig(objectives=[...])` construction path."""
276+
277+
def setUp(self) -> None:
278+
super().setUp()
279+
self.metrics = {
280+
"m1": Metric(name="m1"),
281+
"m2": Metric(name="m2"),
282+
"m3": Metric(name="m3"),
283+
}
284+
self.sig = {m: m for m in self.metrics}
285+
self.obj1 = Objective(expression="m1", metric_name_to_signature=self.sig)
286+
self.obj2 = Objective(expression="-m2", metric_name_to_signature=self.sig)
287+
self.scalarized_obj = Objective(
288+
expression="2*m1 + m2", metric_name_to_signature=self.sig
289+
)
290+
291+
def test_objectives_kwarg_construction(self) -> None:
292+
"""Test single and multi-objective construction via objectives kwarg."""
293+
# Single objective
294+
config = OptimizationConfig(objectives=[self.obj1])
295+
self.assertEqual(config.objectives, [self.obj1])
296+
self.assertEqual(config.objective, self.obj1)
297+
self.assertFalse(config.is_moo_problem)
298+
299+
# Multi-objective
300+
config = OptimizationConfig(objectives=[self.obj1, self.obj2])
301+
self.assertEqual(config.objectives, [self.obj1, self.obj2])
302+
self.assertTrue(config.is_moo_problem)
303+
with self.assertRaisesRegex(UnsupportedError, "multiple objectives"):
304+
config.objective
305+
306+
def test_objectives_kwarg_metric_aggregation(self) -> None:
307+
"""Test metric_names, metric_name_to_signature, metric_signatures."""
308+
constraint = OutcomeConstraint(
309+
expression="m3 >= 0.5", metric_name_to_signature=self.sig
310+
)
311+
config = OptimizationConfig(
312+
objectives=[self.obj1, self.obj2],
313+
outcome_constraints=[constraint],
314+
)
315+
self.assertEqual(config.metric_names, {"m1", "m2", "m3"})
316+
self.assertEqual(
317+
config.metric_name_to_signature, {"m1": "m1", "m2": "m2", "m3": "m3"}
318+
)
319+
self.assertEqual(config.metric_signatures, {"m1", "m2", "m3"})
320+
321+
def test_objectives_kwarg_validation(self) -> None:
322+
"""Test validation errors for objectives kwarg."""
323+
with self.subTest("mutual_exclusivity"):
324+
with self.assertRaisesRegex(UserInputError, "Cannot specify both"):
325+
OptimizationConfig(objective=self.obj1, objectives=[self.obj1])
326+
327+
with self.subTest("neither_specified"):
328+
with self.assertRaisesRegex(UserInputError, "Must specify either"):
329+
OptimizationConfig()
330+
331+
with self.subTest("empty_list"):
332+
with self.assertRaisesRegex(UserInputError, "must not be empty"):
333+
OptimizationConfig(objectives=[])
334+
335+
with self.subTest("multi_objective_expression"):
336+
multi_obj = Objective(
337+
expression="m1, -m2", metric_name_to_signature=self.sig
338+
)
339+
with self.assertRaisesRegex(ValueError, "single or scalarized"):
340+
OptimizationConfig(objectives=[multi_obj])
341+
342+
with self.subTest("duplicate_metric_names"):
343+
obj_dup = Objective(expression="m1", metric_name_to_signature=self.sig)
344+
with self.assertRaisesRegex(UserInputError, "appears in multiple"):
345+
OptimizationConfig(objectives=[self.obj1, obj_dup])
346+
347+
def test_objectives_kwarg_clone_and_repr(self) -> None:
348+
"""Test clone, clone_with_args, and repr for objectives-list configs."""
349+
config = OptimizationConfig(objectives=[self.obj1, self.obj2])
350+
351+
# clone preserves objectives
352+
cloned = config.clone()
353+
self.assertEqual(len(cloned.objectives), 2)
354+
self.assertEqual(cloned.objectives[0].expression, "m1")
355+
self.assertEqual(cloned.objectives[1].expression, "-m2")
356+
self.assertTrue(cloned.is_moo_problem)
357+
358+
# clone_with_args(objective=) replaces the list with a single objective
359+
cloned = config.clone_with_args(objective=self.obj1)
360+
self.assertEqual(len(cloned.objectives), 1)
361+
self.assertFalse(cloned.is_moo_problem)
362+
363+
# clone_with_args(objectives=) replaces the list
364+
obj3 = Objective(expression="m3", metric_name_to_signature=self.sig)
365+
cloned = config.clone_with_args(objectives=[self.obj1, obj3])
366+
self.assertEqual(len(cloned.objectives), 2)
367+
self.assertEqual(cloned.objectives[1].expression, "m3")
368+
369+
# objective= and objectives= are mutually exclusive in clone_with_args
370+
with self.assertRaisesRegex(UserInputError, "Cannot specify both"):
371+
config.clone_with_args(objective=self.obj1, objectives=[self.obj1])
372+
373+
# repr always uses "objectives="
374+
self.assertIn("objectives=", repr(config))
375+
single_config = OptimizationConfig(objectives=[self.obj1])
376+
self.assertIn("objectives=", repr(single_config))
377+
378+
274379
class MultiObjectiveOptimizationConfigTest(TestCase):
275380
def setUp(self) -> None:
276381
super().setUp()

ax/orchestration/tests/test_orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2727,7 +2727,7 @@ def test_generate_candidates_does_not_generate_if_missing_data(self) -> None:
27272727
)
27282728
self.branin_experiment.add_tracking_metric(custom_metric)
27292729
self.branin_experiment.optimization_config = OptimizationConfig(
2730-
Objective(
2730+
objective=Objective(
27312731
metric=CustomTestMetric(
27322732
name="custom_test_metric", test_attribute="test"
27332733
),
@@ -2974,7 +2974,7 @@ def setUp(self) -> None:
29742974
self.branin_experiment_no_impl_runner_or_metrics = MultiTypeExperiment(
29752975
search_space=get_branin_search_space(),
29762976
optimization_config=OptimizationConfig(
2977-
Objective(metric=Metric(name="branin"), minimize=True)
2977+
objective=Objective(metric=Metric(name="branin"), minimize=True)
29782978
),
29792979
default_trial_type="type1",
29802980
default_runner=None,

ax/service/tests/test_best_point_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def test_best_raw_objective_point_scalarized(self) -> None:
615615
exp = get_branin_experiment()
616616
gs = choose_generation_strategy_legacy(search_space=exp.search_space)
617617
exp.optimization_config = OptimizationConfig(
618-
ScalarizedObjective(metrics=[get_branin_metric()], minimize=True)
618+
objective=ScalarizedObjective(metrics=[get_branin_metric()], minimize=True)
619619
)
620620
with self.assertRaisesRegex(ValueError, "Cannot identify best "):
621621
get_best_raw_objective_point_with_trial_index(exp)
@@ -637,7 +637,7 @@ def test_best_raw_objective_point_scalarized_multi(self) -> None:
637637
exp = get_branin_experiment()
638638
gs = choose_generation_strategy_legacy(search_space=exp.search_space)
639639
exp.optimization_config = OptimizationConfig(
640-
ScalarizedObjective(
640+
objective=ScalarizedObjective(
641641
metrics=[get_branin_metric(), get_branin_metric(lower_is_better=False)],
642642
weights=[0.1, -0.9],
643643
minimize=True,
@@ -1037,7 +1037,7 @@ def test_best_parameters_from_model_predictions_scalarized(self) -> None:
10371037
)
10381038
exp.add_tracking_metric(metric2)
10391039
exp.optimization_config = OptimizationConfig(
1040-
ScalarizedObjective(
1040+
objective=ScalarizedObjective(
10411041
metrics=[metric1, metric2],
10421042
weights=[0.5, 0.5],
10431043
minimize=True,

ax/service/tests/test_report_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,12 @@ def _test_get_standard_plots_moo_relative_constraints(
449449
names = obj.metric_names
450450
# Create a new Objective rather than mutating _expression_str to
451451
# avoid stale _parsed cached_property.
452-
none_throws(exp.optimization_config)._objective = Objective(
453-
expression=f"{names[0]}, -{names[1]}",
454-
metric_name_to_signature={n: n for n in names},
455-
)
452+
none_throws(exp.optimization_config)._objectives = [
453+
Objective(
454+
expression=f"{names[0]}, -{names[1]}",
455+
metric_name_to_signature={n: n for n in names},
456+
)
457+
]
456458
exp.get_metric(names[0]).lower_is_better = False
457459
assert_is_instance(
458460
exp.optimization_config, MultiObjectiveOptimizationConfig
@@ -494,10 +496,12 @@ def test_get_standard_plots_moo_no_objective_thresholds(self) -> None:
494496
# first objective to maximize, second to minimize
495497
obj = none_throws(exp.optimization_config).objective
496498
names = obj.metric_names
497-
none_throws(exp.optimization_config)._objective = Objective(
498-
expression=f"{names[0]}, -{names[1]}",
499-
metric_name_to_signature={n: n for n in names},
500-
)
499+
none_throws(exp.optimization_config)._objectives = [
500+
Objective(
501+
expression=f"{names[0]}, -{names[1]}",
502+
metric_name_to_signature={n: n for n in names},
503+
)
504+
]
501505
exp.trials[0].run()
502506
plots = get_standard_plots(
503507
experiment=exp,

ax/storage/json_store/decoder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,13 @@ def object_from_json(
351351
object_json = _sanitize_inputs_to_surrogate_spec(object_json=object_json)
352352
if isclass(_class) and issubclass(_class, OptimizationConfig):
353353
object_json.pop("risk_measure", None) # Deprecated.
354+
# Backward compat: old JSON uses "objective", new uses "objectives".
355+
if (
356+
_class is OptimizationConfig
357+
and "objective" in object_json
358+
and "objectives" not in object_json
359+
):
360+
object_json["objectives"] = [object_json.pop("objective")]
354361
return ax_class_from_json_dict(
355362
_class=_class, object_json=object_json, **vars(registry_kwargs)
356363
)

ax/storage/json_store/encoders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def optimization_config_to_dict(
383383
"""Convert Ax optimization config to a dictionary."""
384384
return {
385385
"__type": optimization_config.__class__.__name__,
386-
"objective": optimization_config.objective,
386+
"objectives": optimization_config.objectives,
387387
"outcome_constraints": optimization_config.outcome_constraints,
388388
"pruning_target_parameterization": (
389389
optimization_config.pruning_target_parameterization
@@ -782,16 +782,17 @@ def _build_opt_config_dict(
782782
will then recursively encode them via ``metric_to_dict``, capturing the
783783
full metric type.
784784
"""
785-
objective_dict = _build_objective_dict(
786-
objective=opt_config.objective, experiment_metrics=experiment_metrics
787-
)
785+
objective_dicts = [
786+
_build_objective_dict(objective=obj, experiment_metrics=experiment_metrics)
787+
for obj in opt_config.objectives
788+
]
788789
constraint_dicts = [
789790
_build_constraint_dict(constraint=c, experiment_metrics=experiment_metrics)
790791
for c in opt_config.outcome_constraints
791792
]
792793
result: dict[str, Any] = {
793794
"__type": opt_config.__class__.__name__,
794-
"objective": objective_dict,
795+
"objectives": objective_dicts,
795796
"outcome_constraints": constraint_dicts,
796797
"pruning_target_parameterization": opt_config.pruning_target_parameterization,
797798
}

ax/storage/json_store/tests/test_json_store.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
get_metric,
138138
get_mll_type,
139139
get_model_type,
140+
get_moo_optimization_config,
140141
get_multi_objective,
141142
get_multi_objective_optimization_config,
142143
get_multi_type_experiment,
@@ -380,6 +381,7 @@
380381
("Objective", get_objective),
381382
("ObjectiveThreshold", get_objective_threshold),
382383
("OptimizationConfig", get_optimization_config),
384+
("OptimizationConfig", get_moo_optimization_config),
383385
("OrEarlyStoppingStrategy", get_or_early_stopping_strategy),
384386
("OrderConstraint", get_order_constraint),
385387
("OutcomeConstraint", get_outcome_constraint),

0 commit comments

Comments (0)