Skip to content

Commit 8cac1d7

Browse files
FBumannclaude
andcommitted
feat: generalize disjunctive formulation to N variables
Refactor _add_disjunctive to use the same stacked N-variable pattern as _add_continuous. Removes the 2-variable restriction — disjunctive now supports any number of (expression, breakpoints) pairs with a single unified link constraint. - Remove separate x_link/y_link in favor of single _link with _pwl_var dim - Remove PWL_Y_LINK_SUFFIX import (no longer needed) - Add test for 3-variable disjunctive Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e23f934 commit 8cac1d7

2 files changed

Lines changed: 80 additions & 53 deletions

File tree

linopy/piecewise.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
PWL_LAMBDA_SUFFIX,
3232
PWL_SELECT_SUFFIX,
3333
PWL_X_LINK_SUFFIX,
34-
PWL_Y_LINK_SUFFIX,
3534
SEGMENT_DIM,
3635
)
3736

@@ -682,21 +681,17 @@ def add_piecewise_constraints(
682681
active_expr = _to_linexpr(active) if active is not None else None
683682

684683
if disjunctive:
685-
# Disjunctive only supports 2-variable for now
686-
if len(coerced) != 2:
684+
if method == "incremental":
687685
raise ValueError(
688-
"Disjunctive piecewise constraints currently support "
689-
"exactly 2 (expression, breakpoints) pairs."
686+
"Incremental method is not supported for disjunctive constraints"
690687
)
691688
return _add_disjunctive(
692689
model,
693690
name,
694-
lin_exprs[0],
695-
lin_exprs[1],
696-
bp_list[0],
697-
bp_list[1],
691+
lin_exprs,
692+
bp_list,
693+
link_coords,
698694
bp_mask,
699-
method,
700695
active_expr,
701696
)
702697

@@ -901,68 +896,81 @@ def _add_incremental(
901896
def _add_disjunctive(
902897
model: Model,
903898
name: str,
904-
x_expr: LinearExpression,
905-
y_expr: LinearExpression,
906-
x_points: DataArray,
907-
y_points: DataArray,
908-
mask: DataArray | None,
909-
method: str,
899+
lin_exprs: list[LinearExpression],
900+
bp_list: list[DataArray],
901+
link_coords: list[str],
902+
bp_mask: DataArray | None,
910903
active: LinearExpression | None = None,
911904
) -> Constraint:
912-
"""Handle disjunctive piecewise equality constraints (2-variable only)."""
913-
if method == "incremental":
914-
raise ValueError(
915-
"Incremental method is not supported for disjunctive constraints"
916-
)
905+
"""Disjunctive SOS2 formulation for N-variable piecewise equality."""
906+
from linopy.expressions import LinearExpression
917907

918-
_validate_numeric_breakpoint_coords(x_points)
919-
if not _has_trailing_nan_only(x_points):
908+
link_dim = "_pwl_var"
909+
stacked_bp = _stack_along_link(bp_list, link_coords, link_dim)
910+
911+
_validate_numeric_breakpoint_coords(stacked_bp)
912+
if not _has_trailing_nan_only(stacked_bp):
920913
raise ValueError(
921914
"Disjunctive SOS2 does not support non-trailing NaN breakpoints. "
922915
"NaN values must only appear at the end of the breakpoint sequence."
923916
)
924917

925-
binary_name = f"{name}{PWL_BINARY_SUFFIX}"
926-
select_name = f"{name}{PWL_SELECT_SUFFIX}"
927-
lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}"
928-
convex_name = f"{name}{PWL_CONVEX_SUFFIX}"
929-
x_link_name = f"{name}{PWL_X_LINK_SUFFIX}"
930-
y_link_name = f"{name}{PWL_Y_LINK_SUFFIX}"
918+
# Stack expressions along link dimension
919+
stacked_data = _stack_along_link(
920+
[e.data for e in lin_exprs], link_coords, link_dim
921+
)
922+
target_expr = LinearExpression(stacked_data, model)
931923

932-
extra = _var_coords_from(x_points, exclude={BREAKPOINT_DIM, SEGMENT_DIM})
924+
# Compute stacked mask
925+
stacked_mask = None
926+
if bp_mask is not None:
927+
stacked_mask = _stack_along_link(
928+
[bp_mask] * len(link_coords), link_coords, link_dim
929+
)
930+
931+
dim = BREAKPOINT_DIM
932+
extra = _var_coords_from(stacked_bp, exclude={dim, SEGMENT_DIM, link_dim})
933933
lambda_coords = extra + [
934-
pd.Index(x_points.coords[SEGMENT_DIM].values, name=SEGMENT_DIM),
935-
pd.Index(x_points.coords[BREAKPOINT_DIM].values, name=BREAKPOINT_DIM),
934+
pd.Index(stacked_bp.coords[SEGMENT_DIM].values, name=SEGMENT_DIM),
935+
pd.Index(stacked_bp.coords[dim].values, name=dim),
936936
]
937937
binary_coords = extra + [
938-
pd.Index(x_points.coords[SEGMENT_DIM].values, name=SEGMENT_DIM),
938+
pd.Index(stacked_bp.coords[SEGMENT_DIM].values, name=SEGMENT_DIM),
939939
]
940940

941-
binary_mask = mask.any(dim=BREAKPOINT_DIM) if mask is not None else None
941+
# Masks
942+
lambda_mask = None
943+
binary_mask = None
944+
if stacked_mask is not None:
945+
# Aggregate across link_dim — all variables must be valid
946+
agg_mask = stacked_mask.all(dim=link_dim)
947+
lambda_mask = agg_mask
948+
binary_mask = agg_mask.any(dim=dim)
949+
950+
binary_name = f"{name}{PWL_BINARY_SUFFIX}"
951+
select_name = f"{name}{PWL_SELECT_SUFFIX}"
952+
lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}"
953+
convex_name = f"{name}{PWL_CONVEX_SUFFIX}"
954+
link_name = f"{name}{PWL_X_LINK_SUFFIX}"
942955

943956
binary_var = model.add_variables(
944957
binary=True, coords=binary_coords, name=binary_name, mask=binary_mask
945958
)
946959

947960
rhs = active if active is not None else 1
948-
select_con = model.add_constraints(
961+
model.add_constraints(
949962
binary_var.sum(dim=SEGMENT_DIM) == rhs, name=select_name
950963
)
951964

952965
lambda_var = model.add_variables(
953-
lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=mask
966+
lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask
954967
)
955968

956-
model.add_sos_constraints(lambda_var, sos_type=2, sos_dim=BREAKPOINT_DIM)
969+
model.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim)
957970

958971
model.add_constraints(
959-
lambda_var.sum(dim=BREAKPOINT_DIM) == binary_var, name=convex_name
972+
lambda_var.sum(dim=dim) == binary_var, name=convex_name
960973
)
961974

962-
x_weighted = (lambda_var * x_points).sum(dim=[SEGMENT_DIM, BREAKPOINT_DIM])
963-
model.add_constraints(x_expr == x_weighted, name=x_link_name)
964-
965-
y_weighted = (lambda_var * y_points).sum(dim=[SEGMENT_DIM, BREAKPOINT_DIM])
966-
model.add_constraints(y_expr == y_weighted, name=y_link_name)
967-
968-
return select_con
975+
weighted = (lambda_var * stacked_bp).sum(dim=[SEGMENT_DIM, dim])
976+
return model.add_constraints(target_expr == weighted, name=link_name)

test/test_piecewise_constraints.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,23 @@ def test_multi_dimensional(self) -> None:
579579
assert "generator" in binary.dims
580580
assert "generator" in lam.dims
581581

582+
def test_three_variables(self) -> None:
583+
"""Disjunctive with 3 variables creates single link constraint."""
584+
m = Model()
585+
x = m.add_variables(name="x")
586+
y = m.add_variables(name="y")
587+
z = m.add_variables(name="z")
588+
m.add_piecewise_constraints(
589+
(x, segments([[0, 10], [50, 100]])),
590+
(y, segments([[0, 5], [20, 80]])),
591+
(z, segments([[0, 3], [15, 60]])),
592+
)
593+
assert f"pwl0{PWL_BINARY_SUFFIX}" in m.variables
594+
assert f"pwl0{PWL_LAMBDA_SUFFIX}" in m.variables
595+
# Single link constraint with _pwl_var dimension
596+
link = m.constraints[f"pwl0{PWL_X_LINK_SUFFIX}"]
597+
assert "_pwl_var" in [str(d) for d in link.dims]
598+
582599

583600
# ===========================================================================
584601
# Validation
@@ -1278,19 +1295,21 @@ def test_segment_dim_mismatch_raises(self) -> None:
12781295
with pytest.raises(ValueError, match="segment dimension"):
12791296
m.add_piecewise_constraints((x, x_pts), (y, y_pts))
12801297

1281-
def test_disjunctive_three_pairs_raises(self) -> None:
1282-
"""Disjunctive with 3 pairs raises ValueError."""
1298+
def test_disjunctive_three_pairs(self) -> None:
1299+
"""Disjunctive with 3 pairs works (N-variable)."""
12831300
m = Model()
12841301
x = m.add_variables(name="x")
12851302
y = m.add_variables(name="y")
12861303
z = m.add_variables(name="z")
12871304
seg = segments([[0, 10], [50, 100]])
1288-
with pytest.raises(ValueError, match="exactly 2"):
1289-
m.add_piecewise_constraints(
1290-
(x, seg),
1291-
(y, seg),
1292-
(z, seg),
1293-
)
1305+
m.add_piecewise_constraints(
1306+
(x, seg),
1307+
(y, seg),
1308+
(z, seg),
1309+
)
1310+
assert f"pwl0{PWL_BINARY_SUFFIX}" in m.variables
1311+
assert f"pwl0{PWL_LAMBDA_SUFFIX}" in m.variables
1312+
assert f"pwl0{PWL_X_LINK_SUFFIX}" in m.constraints
12941313

12951314
def test_disjunctive_interior_nan_raises(self) -> None:
12961315
"""Disjunctive with interior NaN raises ValueError."""

0 commit comments

Comments
 (0)