Skip to content

Commit 5b01def

Browse files
committed
feat: add piecewise linear constraint API
Add `add_piecewise_constraint` method to Model class that creates piecewise linear constraints using SOS2 formulation. Features: - Single Variable or LinearExpression support - Dict of Variables/Expressions for linking multiple quantities - Auto-detection of link_dim from breakpoints coordinates - NaN-based masking with skip_nan_check option for performance - Counter-based name generation for efficiency The SOS2 formulation creates: 1. Lambda variables with bounds [0, 1] for each breakpoint 2. SOS2 constraint ensuring at most two adjacent lambdas are non-zero 3. Convexity constraint: sum(lambda) = 1 4. Linking constraints: expr = sum(lambda * breakpoints)
1 parent 43239f8 commit 5b01def

3 files changed

Lines changed: 794 additions & 0 deletions

File tree

linopy/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535

3636
TERM_DIM = "_term"
3737
STACKED_TERM_DIM = "_stacked_term"
38+
39+
# Piecewise linear constraint constants
40+
PWL_LAMBDA_SUFFIX = "_lambda"
41+
PWL_CONVEX_SUFFIX = "_convex"
42+
PWL_LINK_SUFFIX = "_link"
43+
DEFAULT_BREAKPOINT_DIM = "breakpoint"
3844
GROUPED_TERM_DIM = "_grouped_term"
3945
GROUP_DIM = "_group"
4046
FACTOR_DIM = "_factor"

linopy/model.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@
3535
to_path,
3636
)
3737
from linopy.constants import (
38+
DEFAULT_BREAKPOINT_DIM,
3839
GREATER_EQUAL,
3940
HELPER_DIMS,
4041
LESS_EQUAL,
42+
PWL_CONVEX_SUFFIX,
43+
PWL_LAMBDA_SUFFIX,
44+
PWL_LINK_SUFFIX,
4145
TERM_DIM,
4246
ModelStatus,
4347
TerminationCondition,
@@ -130,6 +134,7 @@ class Model:
130134
"_cCounter",
131135
"_varnameCounter",
132136
"_connameCounter",
137+
"_pwlCounter",
133138
"_blocks",
134139
# TODO: check if these should not be mutable
135140
"_chunk",
@@ -180,6 +185,7 @@ def __init__(
180185
self._cCounter: int = 0
181186
self._varnameCounter: int = 0
182187
self._connameCounter: int = 0
188+
self._pwlCounter: int = 0
183189
self._blocks: DataArray | None = None
184190

185191
self._chunk: T_Chunks = chunk
@@ -591,6 +597,269 @@ def add_sos_constraints(
591597

592598
variable.attrs.update(sos_type=sos_type, sos_dim=sos_dim)
593599

600+
def add_piecewise_constraint(
601+
self,
602+
expr: Variable | LinearExpression | dict[str, Variable | LinearExpression],
603+
breakpoints: DataArray,
604+
link_dim: str | None = None,
605+
dim: str = DEFAULT_BREAKPOINT_DIM,
606+
mask: DataArray | None = None,
607+
name: str | None = None,
608+
skip_nan_check: bool = False,
609+
) -> Constraint:
610+
"""
611+
Add a piecewise linear constraint using SOS2 formulation.
612+
613+
This method creates a piecewise linear constraint that links one or more
614+
variables/expressions together via a set of breakpoints. It uses the SOS2
615+
(Special Ordered Set of type 2) formulation with lambda (interpolation)
616+
variables.
617+
618+
The SOS2 formulation ensures that at most two adjacent lambda variables
619+
can be non-zero, effectively selecting a segment of the piecewise linear
620+
function.
621+
622+
Parameters
623+
----------
624+
expr : Variable, LinearExpression, or dict of these
625+
The variable(s) or expression(s) to be linked by the piecewise constraint.
626+
- If a single Variable/LinearExpression is passed, the breakpoints
627+
directly specify the piecewise points for that expression.
628+
- If a dict is passed, the keys must match coordinates in `link_dim`
629+
of the breakpoints, allowing multiple expressions to be linked.
630+
breakpoints : xr.DataArray
631+
The breakpoint values defining the piecewise linear function.
632+
Must have `dim` as one of its dimensions. If `expr` is a dict,
633+
must also have `link_dim` dimension with coordinates matching the
634+
dict keys.
635+
link_dim : str, optional
636+
The dimension in breakpoints that links to different expressions.
637+
Required when `expr` is a dict. If None and `expr` is a dict,
638+
will attempt to auto-detect from breakpoints dimensions.
639+
dim : str, default "breakpoint"
640+
The dimension in breakpoints that represents the breakpoint index.
641+
This dimension's coordinates must be numeric (used as SOS2 weights).
642+
mask : xr.DataArray, optional
643+
Boolean mask indicating which piecewise constraints are valid.
644+
If None, auto-detected from NaN values in breakpoints (unless
645+
skip_nan_check is True).
646+
name : str, optional
647+
Base name for the generated variables and constraints.
648+
If None, auto-generates names like "pwl0", "pwl1", etc.
649+
skip_nan_check : bool, default False
650+
If True, skip automatic NaN detection in breakpoints. Use this
651+
when you know breakpoints contain no NaN values for better performance.
652+
653+
Returns
654+
-------
655+
Constraint
656+
The convexity constraint (sum of lambda = 1). Lambda variables
657+
and other constraints can be accessed via:
658+
- `model.variables[f"{name}_lambda"]`
659+
- `model.constraints[f"{name}_convex"]`
660+
- `model.constraints[f"{name}_link"]`
661+
662+
Raises
663+
------
664+
ValueError
665+
If expr is not a Variable, LinearExpression, or dict of these.
666+
If breakpoints doesn't have the required dim dimension.
667+
If link_dim cannot be auto-detected when expr is a dict.
668+
If link_dim coordinates don't match dict keys.
669+
If dim coordinates are not numeric.
670+
671+
Examples
672+
--------
673+
Single variable piecewise constraint:
674+
675+
>>> m = Model()
676+
>>> x = m.add_variables(name="x")
677+
>>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"])
678+
>>> m.add_piecewise_constraint(x, breakpoints, dim="bp")
679+
680+
Using an expression:
681+
682+
>>> m = Model()
683+
>>> x = m.add_variables(name="x")
684+
>>> y = m.add_variables(name="y")
685+
>>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"])
686+
>>> m.add_piecewise_constraint(x + y, breakpoints, dim="bp")
687+
688+
Multiple linked variables (e.g., power-efficiency curve):
689+
690+
>>> m = Model()
691+
>>> generators = ["gen1", "gen2"]
692+
>>> power = m.add_variables(coords=[generators], name="power")
693+
>>> efficiency = m.add_variables(coords=[generators], name="efficiency")
694+
>>> breakpoints = xr.DataArray(
695+
... [[0, 50, 100], [0.8, 0.95, 0.9]],
696+
... coords={"var": ["power", "efficiency"], "bp": [0, 1, 2]},
697+
... )
698+
>>> m.add_piecewise_constraint(
699+
... {"power": power, "efficiency": efficiency},
700+
... breakpoints,
701+
... link_dim="var",
702+
... dim="bp",
703+
... )
704+
705+
Notes
706+
-----
707+
The piecewise linear constraint is formulated using SOS2 variables:
708+
709+
1. Lambda variables λ_i with bounds [0, 1] are created for each breakpoint
710+
2. SOS2 constraint ensures at most two adjacent λ_i can be non-zero
711+
3. Convexity constraint: Σ λ_i = 1
712+
4. Linking constraints: expr = Σ λ_i × breakpoint_i (for each expression)
713+
"""
714+
# --- Input validation ---
715+
if dim not in breakpoints.dims:
716+
raise ValueError(
717+
f"breakpoints must have dimension '{dim}', "
718+
f"but only has dimensions {list(breakpoints.dims)}"
719+
)
720+
721+
if not pd.api.types.is_numeric_dtype(breakpoints.coords[dim]):
722+
raise ValueError(
723+
f"Breakpoint dimension '{dim}' must have numeric coordinates "
724+
f"for SOS2 weights, but got {breakpoints.coords[dim].dtype}"
725+
)
726+
727+
# --- Generate names using counter ---
728+
if name is None:
729+
name = f"pwl{self._pwlCounter}"
730+
self._pwlCounter += 1
731+
732+
lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}"
733+
convex_name = f"{name}{PWL_CONVEX_SUFFIX}"
734+
link_name = f"{name}{PWL_LINK_SUFFIX}"
735+
736+
# --- Determine lambda coordinates, mask, and target expression ---
737+
is_single = isinstance(expr, Variable | LinearExpression)
738+
is_dict = isinstance(expr, dict)
739+
740+
if not is_single and not is_dict:
741+
raise ValueError(
742+
f"'expr' must be a Variable, LinearExpression, or dict of these, "
743+
f"got {type(expr)}"
744+
)
745+
746+
if is_single:
747+
# Single expression case
748+
target_expr = self._to_linexpr(expr)
749+
lambda_coords = breakpoints.coords
750+
lambda_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check)
751+
752+
else:
753+
# Dict case - need to validate link_dim and build stacked expression
754+
expr_dict = expr
755+
expr_keys = set(expr_dict.keys())
756+
757+
# Auto-detect or validate link_dim
758+
link_dim = self._resolve_pwl_link_dim(link_dim, breakpoints, dim, expr_keys)
759+
760+
# Build lambda coordinates (exclude link_dim)
761+
lambda_coords = [
762+
pd.Index(breakpoints.coords[d].values, name=d)
763+
for d in breakpoints.dims
764+
if d != link_dim
765+
]
766+
767+
# Compute mask
768+
base_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check)
769+
lambda_mask = base_mask.any(dim=link_dim)
770+
771+
# Build stacked expression from dict
772+
target_expr = self._build_stacked_expr(expr_dict, breakpoints, link_dim)
773+
774+
# --- Common: Create lambda, SOS2, convexity, and linking constraints ---
775+
lambda_var = self.add_variables(
776+
lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask
777+
)
778+
779+
self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim)
780+
781+
convex_con = self.add_constraints(
782+
lambda_var.sum(dim=dim) == 1, name=convex_name
783+
)
784+
785+
weighted_sum = (lambda_var * breakpoints).sum(dim=dim)
786+
self.add_constraints(target_expr == weighted_sum, name=link_name)
787+
788+
return convex_con
789+
790+
def _to_linexpr(self, expr: Variable | LinearExpression) -> LinearExpression:
791+
"""Convert Variable or LinearExpression to LinearExpression."""
792+
if isinstance(expr, LinearExpression):
793+
return expr
794+
return expr.to_linexpr()
795+
796+
def _compute_pwl_mask(
797+
self,
798+
mask: DataArray | None,
799+
breakpoints: DataArray,
800+
skip_nan_check: bool,
801+
) -> DataArray | None:
802+
"""Compute mask for piecewise constraint, optionally skipping NaN check."""
803+
if mask is not None:
804+
return mask
805+
if skip_nan_check:
806+
return None
807+
return ~breakpoints.isnull()
808+
809+
def _resolve_pwl_link_dim(
810+
self,
811+
link_dim: str | None,
812+
breakpoints: DataArray,
813+
dim: str,
814+
expr_keys: set[str],
815+
) -> str:
816+
"""Auto-detect or validate link_dim for dict case."""
817+
if link_dim is None:
818+
for d in breakpoints.dims:
819+
if d == dim:
820+
continue
821+
coords_set = set(str(c) for c in breakpoints.coords[d].values)
822+
if coords_set == expr_keys:
823+
return str(d)
824+
raise ValueError(
825+
"Could not auto-detect link_dim. Please specify it explicitly. "
826+
f"Breakpoint dimensions: {list(breakpoints.dims)}, "
827+
f"expression keys: {list(expr_keys)}"
828+
)
829+
830+
if link_dim not in breakpoints.dims:
831+
raise ValueError(
832+
f"link_dim '{link_dim}' not found in breakpoints dimensions "
833+
f"{list(breakpoints.dims)}"
834+
)
835+
coords_set = set(str(c) for c in breakpoints.coords[link_dim].values)
836+
if coords_set != expr_keys:
837+
raise ValueError(
838+
f"link_dim '{link_dim}' coordinates {coords_set} "
839+
f"don't match expression keys {expr_keys}"
840+
)
841+
return link_dim
842+
843+
def _build_stacked_expr(
844+
self,
845+
expr_dict: dict[str, Variable | LinearExpression],
846+
breakpoints: DataArray,
847+
link_dim: str,
848+
) -> LinearExpression:
849+
"""Build a stacked LinearExpression from a dict of Variables/Expressions."""
850+
link_coords = list(breakpoints.coords[link_dim].values)
851+
852+
# Collect expression data and stack
853+
expr_data_list = []
854+
for k in link_coords:
855+
e = expr_dict[str(k)]
856+
linexpr = self._to_linexpr(e)
857+
expr_data_list.append(linexpr.data.expand_dims({link_dim: [k]}))
858+
859+
# Concatenate along link_dim
860+
stacked_data = xr.concat(expr_data_list, dim=link_dim)
861+
return LinearExpression(stacked_data, self)
862+
594863
def add_constraints(
595864
self,
596865
lhs: VariableLike

0 commit comments

Comments
 (0)