|
21 | 21 | ScalarizedOutcomeConstraint, |
22 | 22 | ) |
23 | 23 | from ax.core.types import ComparisonOp |
24 | | -from ax.exceptions.core import UserInputError |
| 24 | +from ax.exceptions.core import UnsupportedError, UserInputError |
25 | 25 | from ax.utils.common.testutils import TestCase |
26 | 26 | from pyre_extensions import assert_is_instance |
27 | 27 |
|
28 | 28 |
|
29 | 29 | OC_STR = ( |
30 | 30 | "OptimizationConfig(" |
31 | | - 'objective=Objective(expression="m1"), ' |
| 31 | + 'objectives=[Objective(expression="m1")], ' |
32 | 32 | "outcome_constraints=[OutcomeConstraint(m3 >= -0.25), " |
33 | 33 | "OutcomeConstraint(m4 <= 0.25), " |
34 | 34 | "ScalarizedOutcomeConstraint(0.5*m3 + 0.5*m4 >= 0.9975 * baseline)])" |
@@ -271,6 +271,111 @@ def test_CloneWithArgs(self) -> None: |
271 | 271 | ) |
272 | 272 |
|
273 | 273 |
|
| 274 | +class OptimizationConfigObjectivesListTest(TestCase): |
| 275 | + """Tests for the new `OptimizationConfig(objectives=[...])` construction path.""" |
| 276 | + |
| 277 | + def setUp(self) -> None: |
| 278 | + super().setUp() |
| 279 | + self.metrics = { |
| 280 | + "m1": Metric(name="m1"), |
| 281 | + "m2": Metric(name="m2"), |
| 282 | + "m3": Metric(name="m3"), |
| 283 | + } |
| 284 | + self.sig = {m: m for m in self.metrics} |
| 285 | + self.obj1 = Objective(expression="m1", metric_name_to_signature=self.sig) |
| 286 | + self.obj2 = Objective(expression="-m2", metric_name_to_signature=self.sig) |
| 287 | + self.scalarized_obj = Objective( |
| 288 | + expression="2*m1 + m2", metric_name_to_signature=self.sig |
| 289 | + ) |
| 290 | + |
| 291 | + def test_objectives_kwarg_construction(self) -> None: |
| 292 | + """Test single and multi-objective construction via objectives kwarg.""" |
| 293 | + # Single objective |
| 294 | + config = OptimizationConfig(objectives=[self.obj1]) |
| 295 | + self.assertEqual(config.objectives, [self.obj1]) |
| 296 | + self.assertEqual(config.objective, self.obj1) |
| 297 | + self.assertFalse(config.is_moo_problem) |
| 298 | + |
| 299 | + # Multi-objective |
| 300 | + config = OptimizationConfig(objectives=[self.obj1, self.obj2]) |
| 301 | + self.assertEqual(config.objectives, [self.obj1, self.obj2]) |
| 302 | + self.assertTrue(config.is_moo_problem) |
| 303 | + with self.assertRaisesRegex(UnsupportedError, "multiple objectives"): |
| 304 | + config.objective |
| 305 | + |
| 306 | + def test_objectives_kwarg_metric_aggregation(self) -> None: |
| 307 | + """Test metric_names, metric_name_to_signature, metric_signatures.""" |
| 308 | + constraint = OutcomeConstraint( |
| 309 | + expression="m3 >= 0.5", metric_name_to_signature=self.sig |
| 310 | + ) |
| 311 | + config = OptimizationConfig( |
| 312 | + objectives=[self.obj1, self.obj2], |
| 313 | + outcome_constraints=[constraint], |
| 314 | + ) |
| 315 | + self.assertEqual(config.metric_names, {"m1", "m2", "m3"}) |
| 316 | + self.assertEqual( |
| 317 | + config.metric_name_to_signature, {"m1": "m1", "m2": "m2", "m3": "m3"} |
| 318 | + ) |
| 319 | + self.assertEqual(config.metric_signatures, {"m1", "m2", "m3"}) |
| 320 | + |
| 321 | + def test_objectives_kwarg_validation(self) -> None: |
| 322 | + """Test validation errors for objectives kwarg.""" |
| 323 | + with self.subTest("mutual_exclusivity"): |
| 324 | + with self.assertRaisesRegex(UserInputError, "Cannot specify both"): |
| 325 | + OptimizationConfig(objective=self.obj1, objectives=[self.obj1]) |
| 326 | + |
| 327 | + with self.subTest("neither_specified"): |
| 328 | + with self.assertRaisesRegex(UserInputError, "Must specify either"): |
| 329 | + OptimizationConfig() |
| 330 | + |
| 331 | + with self.subTest("empty_list"): |
| 332 | + with self.assertRaisesRegex(UserInputError, "must not be empty"): |
| 333 | + OptimizationConfig(objectives=[]) |
| 334 | + |
| 335 | + with self.subTest("multi_objective_expression"): |
| 336 | + multi_obj = Objective( |
| 337 | + expression="m1, -m2", metric_name_to_signature=self.sig |
| 338 | + ) |
| 339 | + with self.assertRaisesRegex(ValueError, "single or scalarized"): |
| 340 | + OptimizationConfig(objectives=[multi_obj]) |
| 341 | + |
| 342 | + with self.subTest("duplicate_metric_names"): |
| 343 | + obj_dup = Objective(expression="m1", metric_name_to_signature=self.sig) |
| 344 | + with self.assertRaisesRegex(UserInputError, "appears in multiple"): |
| 345 | + OptimizationConfig(objectives=[self.obj1, obj_dup]) |
| 346 | + |
| 347 | + def test_objectives_kwarg_clone_and_repr(self) -> None: |
| 348 | + """Test clone, clone_with_args, and repr for objectives-list configs.""" |
| 349 | + config = OptimizationConfig(objectives=[self.obj1, self.obj2]) |
| 350 | + |
| 351 | + # clone preserves objectives |
| 352 | + cloned = config.clone() |
| 353 | + self.assertEqual(len(cloned.objectives), 2) |
| 354 | + self.assertEqual(cloned.objectives[0].expression, "m1") |
| 355 | + self.assertEqual(cloned.objectives[1].expression, "-m2") |
| 356 | + self.assertTrue(cloned.is_moo_problem) |
| 357 | + |
| 358 | + # clone_with_args(objective=) replaces the list with a single objective |
| 359 | + cloned = config.clone_with_args(objective=self.obj1) |
| 360 | + self.assertEqual(len(cloned.objectives), 1) |
| 361 | + self.assertFalse(cloned.is_moo_problem) |
| 362 | + |
| 363 | + # clone_with_args(objectives=) replaces the list |
| 364 | + obj3 = Objective(expression="m3", metric_name_to_signature=self.sig) |
| 365 | + cloned = config.clone_with_args(objectives=[self.obj1, obj3]) |
| 366 | + self.assertEqual(len(cloned.objectives), 2) |
| 367 | + self.assertEqual(cloned.objectives[1].expression, "m3") |
| 368 | + |
| 369 | + # objective= and objectives= are mutually exclusive in clone_with_args |
| 370 | + with self.assertRaisesRegex(UserInputError, "Cannot specify both"): |
| 371 | + config.clone_with_args(objective=self.obj1, objectives=[self.obj1]) |
| 372 | + |
| 373 | + # repr always uses "objectives=" |
| 374 | + self.assertIn("objectives=", repr(config)) |
| 375 | + single_config = OptimizationConfig(objectives=[self.obj1]) |
| 376 | + self.assertIn("objectives=", repr(single_config)) |
| 377 | + |
| 378 | + |
274 | 379 | class MultiObjectiveOptimizationConfigTest(TestCase): |
275 | 380 | def setUp(self) -> None: |
276 | 381 | super().setUp() |
|
0 commit comments