Skip to content

Commit 5d71b5d

Browse files
Add reformulate_sos='auto' support to solve() (#595)
* feat: add reformulate_sos='auto' support to solve() - Accept 'auto' as string literal in reformulate_sos parameter (line 1230) - When reformulate_sos='auto' and solver lacks SOS support, silently reformulate - When reformulate_sos='auto' and solver supports SOS natively, pass through without warning - Update error message to mention both True and 'auto' options (line 1424) - Add comprehensive test suite with 5 new test cases covering all scenarios - All 57 SOS reformulation tests pass * fix: improve reformulate_sos validation, DRY up branching, strengthen tests Validate reformulate_sos input early, collapse duplicate True/auto branches, fix docstring type notation, add tests for invalid values and no-SOS no-op, strengthen SOS2 test to actually verify adjacency constraint enforcement. * fix: resolve mypy errors in piecewise and SOS reformulation tests Widen segment types from list[list[float]] to list[Sequence[float]] and add missing type annotations in test fixtures.
1 parent fc5fa6f commit 5d71b5d

File tree

3 files changed

+140
-17
lines changed

3 files changed

+140
-17
lines changed

linopy/model.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ def solve(
12271227
remote: RemoteHandler | OetcHandler = None, # type: ignore
12281228
progress: bool | None = None,
12291229
mock_solve: bool = False,
1230-
reformulate_sos: bool = False,
1230+
reformulate_sos: bool | Literal["auto"] = False,
12311231
**solver_options: Any,
12321232
) -> tuple[str, str]:
12331233
"""
@@ -1297,9 +1297,12 @@ def solve(
12971297
than 10000 variables and constraints.
12981298
mock_solve : bool, optional
12991299
Whether to run a mock solve. This will skip the actual solving. Variables will be set to have dummy values
1300-
reformulate_sos : bool, optional
1300+
reformulate_sos : bool | Literal["auto"], optional
13011301
Whether to automatically reformulate SOS constraints as binary + linear
13021302
constraints for solvers that don't support them natively.
1303+
If True, always reformulates (warns if solver supports SOS natively).
1304+
If "auto", silently reformulates only when the solver lacks SOS support.
1305+
If False, raises if solver doesn't support SOS.
13031306
This uses the Big-M method and requires all SOS variables to have finite bounds.
13041307
Default is False.
13051308
**solver_options : kwargs
@@ -1399,24 +1402,27 @@ def solve(
13991402
f"Solver {solver_name} does not support quadratic problems."
14001403
)
14011404

1405+
if reformulate_sos not in (True, False, "auto"):
1406+
raise ValueError(
1407+
f"Invalid value for reformulate_sos: {reformulate_sos!r}. "
1408+
"Must be True, False, or 'auto'."
1409+
)
1410+
14021411
sos_reform_result = None
14031412
if self.variables.sos:
1404-
if reformulate_sos and not solver_supports(
1405-
solver_name, SolverFeature.SOS_CONSTRAINTS
1406-
):
1413+
supports_sos = solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS)
1414+
if reformulate_sos in (True, "auto") and not supports_sos:
14071415
logger.info(f"Reformulating SOS constraints for solver {solver_name}")
14081416
sos_reform_result = reformulate_sos_constraints(self)
1409-
elif reformulate_sos and solver_supports(
1410-
solver_name, SolverFeature.SOS_CONSTRAINTS
1411-
):
1417+
elif reformulate_sos is True and supports_sos:
14121418
logger.warning(
14131419
f"Solver {solver_name} supports SOS natively; "
14141420
"reformulate_sos=True is ignored."
14151421
)
1416-
elif not solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS):
1422+
elif reformulate_sos is False and not supports_sos:
14171423
raise ValueError(
14181424
f"Solver {solver_name} does not support SOS constraints. "
1419-
"Use reformulate_sos=True or a solver that supports SOS (gurobi, cplex)."
1425+
"Use reformulate_sos=True or 'auto', or a solver that supports SOS (gurobi, cplex)."
14201426
)
14211427

14221428
try:

linopy/piecewise.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from __future__ import annotations
99

10-
from collections.abc import Mapping
10+
from collections.abc import Mapping, Sequence
1111
from typing import TYPE_CHECKING, Literal
1212

1313
import numpy as np
@@ -58,7 +58,7 @@ def _dict_to_array(d: dict[str, list[float]], dim: str, bp_dim: str) -> DataArra
5858

5959

6060
def _segments_list_to_array(
61-
values: list[list[float]], bp_dim: str, seg_dim: str
61+
values: list[Sequence[float]], bp_dim: str, seg_dim: str
6262
) -> DataArray:
6363
max_len = max(len(seg) for seg in values)
6464
data = np.full((len(values), max_len), np.nan)
@@ -72,7 +72,7 @@ def _segments_list_to_array(
7272

7373

7474
def _dict_segments_to_array(
75-
d: dict[str, list[list[float]]], dim: str, bp_dim: str, seg_dim: str
75+
d: dict[str, list[Sequence[float]]], dim: str, bp_dim: str, seg_dim: str
7676
) -> DataArray:
7777
parts = []
7878
for key, seg_list in d.items():
@@ -138,7 +138,9 @@ def _resolve_kwargs(
138138

139139

140140
def _resolve_segment_kwargs(
141-
kwargs: dict[str, list[list[float]] | dict[str, list[list[float]]] | DataArray],
141+
kwargs: dict[
142+
str, list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray
143+
],
142144
dim: str | None,
143145
bp_dim: str,
144146
seg_dim: str,
@@ -235,13 +237,13 @@ def __call__(
235237

236238
def segments(
237239
self,
238-
values: list[list[float]] | dict[str, list[list[float]]] | None = None,
240+
values: list[Sequence[float]] | dict[str, list[Sequence[float]]] | None = None,
239241
*,
240242
dim: str | None = None,
241243
bp_dim: str = DEFAULT_BREAKPOINT_DIM,
242244
seg_dim: str = DEFAULT_SEGMENT_DIM,
243245
link_dim: str = DEFAULT_LINK_DIM,
244-
**kwargs: list[list[float]] | dict[str, list[list[float]]] | DataArray,
246+
**kwargs: list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray,
245247
) -> DataArray:
246248
"""
247249
Create a segmented breakpoint DataArray for disjunctive piecewise constraints.

test/test_sos_reformulation.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from __future__ import annotations
44

5+
import logging
6+
57
import numpy as np
68
import pandas as pd
79
import pytest
810

9-
from linopy import Model, available_solvers
11+
from linopy import Model, Variable, available_solvers
1012
from linopy.constants import SOS_TYPE_ATTR
1113
from linopy.sos_reformulation import (
1214
compute_big_m_values,
@@ -816,3 +818,116 @@ def test_sos1_unsorted_coords(self) -> None:
816818

817819
assert m.objective.value is not None
818820
assert np.isclose(m.objective.value, 3, atol=1e-5)
821+
822+
823+
@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed")
824+
class TestAutoReformulation:
825+
"""Tests for reformulate_sos='auto' functionality."""
826+
827+
@pytest.fixture()
828+
def sos1_model(self) -> tuple[Model, Variable]:
829+
m = Model()
830+
idx = pd.Index([0, 1, 2], name="i")
831+
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
832+
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
833+
m.add_objective(x * np.array([1, 2, 3]), sense="max")
834+
return m, x
835+
836+
def test_auto_reformulates_when_solver_lacks_sos(
837+
self, sos1_model: tuple[Model, Variable]
838+
) -> None:
839+
m, x = sos1_model
840+
m.solve(solver_name="highs", reformulate_sos="auto")
841+
842+
assert np.isclose(x.solution.values[2], 1, atol=1e-5)
843+
assert np.isclose(x.solution.values[0], 0, atol=1e-5)
844+
assert np.isclose(x.solution.values[1], 0, atol=1e-5)
845+
assert m.objective.value is not None
846+
assert np.isclose(m.objective.value, 3, atol=1e-5)
847+
848+
def test_auto_with_sos2(self) -> None:
849+
m = Model()
850+
idx = pd.Index([0, 1, 2, 3], name="i")
851+
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
852+
m.add_sos_constraints(x, sos_type=2, sos_dim="i")
853+
m.add_objective(x * np.array([10, 1, 1, 10]), sense="max")
854+
855+
m.solve(solver_name="highs", reformulate_sos="auto")
856+
857+
assert m.objective.value is not None
858+
nonzero_indices = np.where(np.abs(x.solution.values) > 1e-5)[0]
859+
assert len(nonzero_indices) <= 2
860+
if len(nonzero_indices) == 2:
861+
assert abs(nonzero_indices[1] - nonzero_indices[0]) == 1
862+
assert not np.isclose(m.objective.value, 20, atol=1e-5)
863+
864+
def test_auto_emits_info_no_warning(
865+
self, sos1_model: tuple[Model, Variable], caplog: pytest.LogCaptureFixture
866+
) -> None:
867+
m, _ = sos1_model
868+
869+
with caplog.at_level(logging.INFO):
870+
m.solve(solver_name="highs", reformulate_sos="auto")
871+
872+
assert any("Reformulating SOS" in msg for msg in caplog.messages)
873+
assert not any("supports SOS natively" in msg for msg in caplog.messages)
874+
875+
@pytest.mark.skipif(
876+
"gurobi" not in available_solvers, reason="Gurobi not installed"
877+
)
878+
def test_auto_passes_through_native_sos_without_reformulation(self) -> None:
879+
import gurobipy
880+
881+
m = Model()
882+
idx = pd.Index([0, 1, 2], name="i")
883+
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
884+
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
885+
m.add_objective(x * np.array([1, 2, 3]), sense="max")
886+
887+
try:
888+
m.solve(solver_name="gurobi", reformulate_sos="auto")
889+
except gurobipy.GurobiError as exc:
890+
pytest.skip(f"Gurobi environment unavailable: {exc}")
891+
892+
assert m.objective.value is not None
893+
assert np.isclose(m.objective.value, 3, atol=1e-5)
894+
assert np.isclose(x.solution.values[2], 1, atol=1e-5)
895+
assert np.isclose(x.solution.values[0], 0, atol=1e-5)
896+
assert np.isclose(x.solution.values[1], 0, atol=1e-5)
897+
898+
def test_auto_multidimensional_sos1(self) -> None:
899+
m = Model()
900+
idx_i = pd.Index([0, 1, 2], name="i")
901+
idx_j = pd.Index([0, 1], name="j")
902+
x = m.add_variables(lower=0, upper=1, coords=[idx_i, idx_j], name="x")
903+
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
904+
m.add_objective(x.sum(), sense="max")
905+
906+
m.solve(solver_name="highs", reformulate_sos="auto")
907+
908+
assert m.objective.value is not None
909+
assert np.isclose(m.objective.value, 2, atol=1e-5)
910+
for j in idx_j:
911+
nonzero_count = (np.abs(x.solution.sel(j=j).values) > 1e-5).sum()
912+
assert nonzero_count <= 1
913+
914+
def test_auto_noop_without_sos(self) -> None:
915+
m = Model()
916+
idx = pd.Index([0, 1, 2], name="i")
917+
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
918+
m.add_objective(x.sum(), sense="max")
919+
920+
m.solve(solver_name="highs", reformulate_sos="auto")
921+
922+
assert m.objective.value is not None
923+
assert np.isclose(m.objective.value, 3, atol=1e-5)
924+
925+
def test_invalid_reformulate_sos_value(self) -> None:
926+
m = Model()
927+
idx = pd.Index([0, 1, 2], name="i")
928+
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
929+
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
930+
m.add_objective(x.sum(), sense="max")
931+
932+
with pytest.raises(ValueError, match="Invalid value for reformulate_sos"):
933+
m.solve(solver_name="highs", reformulate_sos="invalid") # type: ignore[arg-type]

0 commit comments

Comments
 (0)