Skip to content

Commit 351ef71

Browse files
committed
feat: support mixed per-element signs in frozen Constraint, add rhs/lhs setters
- Store _sign as str (uniform, fast) or np.ndarray (mixed, per-row) - Add rhs/lhs setters on Constraint via _refreeze_after pattern - Invalidate _dual on mutation; update netcdf serialization for array signs - Add tests for setters, mixed-sign freeze/roundtrip/sanitize/repr/netcdf
1 parent 4d65f7f commit 351ef71

File tree

3 files changed

+200
-33
lines changed

3 files changed

+200
-33
lines changed

linopy/constraints.py

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def __init__(
529529
csr: scipy.sparse.csr_array,
530530
con_labels: np.ndarray,
531531
rhs: np.ndarray,
532-
sign: str,
532+
sign: str | np.ndarray,
533533
coords: list[pd.Index],
534534
model: Model,
535535
name: str = "",
@@ -614,7 +614,7 @@ def coord_names(self) -> list[str]:
614614
return [str(c.name) for c in self._coords]
615615

616616
def _active_to_dataarray(
617-
self, active_values: np.ndarray, fill: float | int = -1
617+
self, active_values: np.ndarray, fill: float | int | str = -1
618618
) -> DataArray:
619619
full = np.full(self.full_size, fill, dtype=active_values.dtype)
620620
full[self.active_positions] = active_values
@@ -651,14 +651,40 @@ def vars(self) -> DataArray:
651651

652652
@property
653653
def sign(self) -> DataArray:
654-
"""Get sign DataArray (scalar, same sign for all entries)."""
655-
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
654+
"""Get sign DataArray."""
655+
if isinstance(self._sign, str):
656+
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
657+
return self._active_to_dataarray(self._sign, fill="")
656658

657659
@property
658660
def rhs(self) -> DataArray:
659661
"""Get RHS DataArray, shape (*coord_dims)."""
660662
return self._active_to_dataarray(self._rhs, fill=np.nan)
661663

664+
@rhs.setter
665+
def rhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
666+
self._refreeze_after(lambda mc: setattr(mc, "rhs", value))
667+
668+
@property
669+
def lhs(self) -> expressions.LinearExpression:
670+
"""Get LHS as LinearExpression (triggers Dataset reconstruction)."""
671+
return self.mutable().lhs
672+
673+
@lhs.setter
674+
def lhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
675+
self._refreeze_after(lambda mc: setattr(mc, "lhs", value))
676+
677+
def _refreeze_after(self, mutate: Callable[[MutableConstraint], None]) -> None:
678+
mc = self.mutable()
679+
mutate(mc)
680+
refrozen = Constraint.from_mutable(mc, self._cindex)
681+
self._csr = refrozen._csr
682+
self._con_labels = refrozen._con_labels
683+
self._rhs = refrozen._rhs
684+
self._sign = refrozen._sign
685+
self._coords = refrozen._coords
686+
self._dual = None
687+
662688
@property
663689
@has_optimized_model
664690
def dual(self) -> DataArray:
@@ -763,7 +789,8 @@ def row_expr(row: int) -> str:
763789
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
764790
vars_row[: end - start] = vlabels[csr.indices[start:end]]
765791
coeffs_row[: end - start] = csr.data[start:end]
766-
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
792+
sign = self._sign if isinstance(self._sign, str) else self._sign[row]
793+
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[sign]} {self._rhs[row]}"
767794

768795
if size > 1:
769796
for indices in generate_indices_for_printout(shape, max_lines):
@@ -805,21 +832,22 @@ def to_netcdf_ds(self) -> Dataset:
805832
"rhs": DataArray(self._rhs, dims=["_flat"]),
806833
"_con_labels": DataArray(self._con_labels, dims=["_flat"]),
807834
}
835+
if isinstance(self._sign, np.ndarray):
836+
data_vars["_sign"] = DataArray(self._sign, dims=["_flat"])
808837
data_vars.update(coords_to_dataset_vars(self._coords))
809838
if self._dual is not None:
810839
data_vars["dual"] = DataArray(self._dual, dims=["_flat"])
811840
dim_names = [c.name for c in self._coords]
812-
return Dataset(
813-
data_vars,
814-
attrs={
815-
"_linopy_format": "csr",
816-
"sign": self._sign,
817-
"cindex": self._cindex if self._cindex is not None else -1,
818-
"shape": list(csr.shape),
819-
"coord_dims": dim_names,
820-
"name": self._name,
821-
},
822-
)
841+
attrs: dict[str, Any] = {
842+
"_linopy_format": "csr",
843+
"cindex": self._cindex if self._cindex is not None else -1,
844+
"shape": list(csr.shape),
845+
"coord_dims": dim_names,
846+
"name": self._name,
847+
}
848+
if isinstance(self._sign, str):
849+
attrs["sign"] = self._sign
850+
return Dataset(data_vars, attrs=attrs)
823851

824852
@classmethod
825853
def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
@@ -831,7 +859,9 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
831859
shape=shape,
832860
)
833861
rhs = ds["rhs"].values
834-
sign = attrs["sign"]
862+
sign: str | np.ndarray = (
863+
ds["_sign"].values if "_sign" in ds else attrs["sign"]
864+
)
835865
_cindex_raw = int(attrs["cindex"])
836866
cindex: int | None = _cindex_raw if _cindex_raw >= 0 else None
837867
coord_dims = attrs["coord_dims"]
@@ -859,7 +889,10 @@ def to_matrix_with_rhs(
859889
self, label_index: VariableLabelIndex
860890
) -> tuple[scipy.sparse.csr_array, np.ndarray, np.ndarray, np.ndarray]:
861891
"""Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
862-
sense = np.full(len(self._rhs), self._sign[0])
892+
if isinstance(self._sign, str):
893+
sense = np.full(len(self._rhs), self._sign[0])
894+
else:
895+
sense = np.array([s[0] for s in self._sign])
863896
return self._csr, self._con_labels, self._rhs, sense
864897

865898
def sanitize_zeros(self) -> Constraint:
@@ -874,18 +907,25 @@ def sanitize_missings(self) -> Constraint:
874907

875908
def sanitize_infinities(self) -> Constraint:
876909
"""Mask out rows with invalid infinite RHS values (mutates in-place)."""
877-
if self._sign == LESS_EQUAL:
878-
invalid = self._rhs == np.inf
879-
elif self._sign == GREATER_EQUAL:
880-
invalid = self._rhs == -np.inf
910+
if isinstance(self._sign, str):
911+
if self._sign == LESS_EQUAL:
912+
invalid = self._rhs == np.inf
913+
elif self._sign == GREATER_EQUAL:
914+
invalid = self._rhs == -np.inf
915+
else:
916+
return self
881917
else:
882-
return self
918+
invalid = ((self._sign == LESS_EQUAL) & (self._rhs == np.inf)) | (
919+
(self._sign == GREATER_EQUAL) & (self._rhs == -np.inf)
920+
)
883921
if not invalid.any():
884922
return self
885923
keep = ~invalid
886924
self._csr = self._csr[keep]
887925
self._con_labels = self._con_labels[keep]
888926
self._rhs = self._rhs[keep]
927+
if not isinstance(self._sign, str):
928+
self._sign = self._sign[keep]
889929
return self
890930

891931
def freeze(self) -> Constraint:
@@ -925,13 +965,14 @@ def from_mutable(
925965
active_mask = (labels_flat != -1) & (vars_flat != -1).any(axis=1)
926966
rhs = con.rhs.values.ravel()[active_mask]
927967
sign_vals = con.sign.values.ravel()
928-
unique_signs = np.unique(sign_vals[active_mask])
929-
if len(unique_signs) > 1:
930-
raise ValueError(
931-
"Constraint has per-element signs; cannot freeze to immutable Constraint. "
932-
"This is a known limitation — use MutableConstraint instead."
933-
)
934-
sign = str(unique_signs[0]) if len(unique_signs) == 1 else "="
968+
active_signs = sign_vals[active_mask]
969+
unique_signs = np.unique(active_signs)
970+
if len(unique_signs) == 0:
971+
sign: str | np.ndarray = "="
972+
elif len(unique_signs) == 1:
973+
sign = str(unique_signs[0])
974+
else:
975+
sign = active_signs
935976
dual = (
936977
con.data["dual"].values.ravel()[active_mask] if "dual" in con.data else None
937978
)

test/test_constraint.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,15 +733,17 @@ def test_freeze_mutable_roundtrip_with_masking() -> None:
733733
assert frozen.ncons == refrozen.ncons == 3
734734

735735

736-
def test_from_mutable_mixed_signs_raises() -> None:
736+
def test_from_mutable_mixed_signs() -> None:
737737
m = Model()
738738
x = m.add_variables(coords=[pd.RangeIndex(3, name="i")], name="x")
739739
m.add_constraints(x >= 0, name="mixed", freeze=False)
740740
mc = m.constraints["mixed"]
741741
assert isinstance(mc, MutableConstraint)
742742
mc._data["sign"] = xr.DataArray(["<=", ">=", "<="], dims=["i"])
743-
with pytest.raises(ValueError, match="per-element signs"):
744-
Constraint.from_mutable(mc)
743+
frozen = Constraint.from_mutable(mc)
744+
assert isinstance(frozen._sign, np.ndarray)
745+
assert list(frozen._sign) == ["<=", ">=", "<="]
746+
assert_equal(frozen.sign, mc.sign)
745747

746748

747749
def test_variable_label_index(m: Model) -> None:
@@ -786,3 +788,102 @@ def test_constraint_repr_shows_variable_names(m: Model) -> None:
786788
c = m.constraints["c"]
787789
r = repr(c)
788790
assert "x" in r
791+
792+
793+
def test_freeze_mixed_signs_from_rule() -> None:
794+
m = Model()
795+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
796+
coords = [pd.RangeIndex(4, name="i")]
797+
798+
def bound(m, i):
799+
if i % 2:
800+
return x.at[i] >= i
801+
return x.at[i] == 0.0
802+
803+
con = m.add_constraints(bound, coords=coords, name="mixed_rule")
804+
assert isinstance(con, Constraint)
805+
assert isinstance(con._sign, np.ndarray)
806+
assert con.ncons == 4
807+
expected_signs = ["=", ">=", "=", ">="]
808+
assert list(con._sign) == expected_signs
809+
np.testing.assert_array_equal(con.sign.values, expected_signs)
810+
811+
812+
def test_frozen_rhs_setter() -> None:
813+
m = Model()
814+
time = pd.RangeIndex(5, name="t")
815+
x = m.add_variables(lower=0, coords=[time], name="x")
816+
con = m.add_constraints(x >= 1, name="c")
817+
assert isinstance(con, Constraint)
818+
con.rhs = 10
819+
np.testing.assert_array_equal(con._rhs, np.full(5, 10.0))
820+
factor = pd.Series(range(5), index=time)
821+
con.rhs = 2 * factor
822+
np.testing.assert_array_equal(con._rhs, 2 * np.arange(5, dtype=float))
823+
824+
825+
def test_frozen_lhs_setter() -> None:
826+
m = Model()
827+
time = pd.RangeIndex(5, name="t")
828+
x = m.add_variables(lower=0, coords=[time], name="x")
829+
y = m.add_variables(lower=0, coords=[time], name="y")
830+
con = m.add_constraints(x >= 0, name="c")
831+
assert isinstance(con, Constraint)
832+
con.lhs = 3 * x + 2 * y
833+
lhs = con.mutable().lhs
834+
assert lhs.nterm == 2
835+
836+
837+
def test_frozen_setter_invalidates_dual() -> None:
838+
m = Model()
839+
x = m.add_variables(lower=0, coords=[pd.RangeIndex(3, name="i")], name="x")
840+
con = m.add_constraints(x >= 0, name="c")
841+
con._dual = np.array([1.0, 2.0, 3.0])
842+
con.rhs = 10
843+
assert con._dual is None
844+
845+
846+
def test_mixed_sign_to_matrix_with_rhs() -> None:
847+
m = Model()
848+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
849+
coords = [pd.RangeIndex(4, name="i")]
850+
851+
def bound(m, i):
852+
if i % 2:
853+
return x.at[i] >= i
854+
return x.at[i] == 0.0
855+
856+
con = m.add_constraints(bound, coords=coords, name="c")
857+
li = m.variables.label_index
858+
csr, con_labels, b, sense = con.to_matrix_with_rhs(li)
859+
assert len(sense) == 4
860+
assert list(sense) == ["=", ">", "=", ">"]
861+
862+
863+
def test_mixed_sign_sanitize_infinities() -> None:
864+
m = Model()
865+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
866+
m.add_constraints(x >= 0, name="c", freeze=False)
867+
mc = m.constraints["c"]
868+
mc._data["sign"] = xr.DataArray(["<=", ">=", "<=", ">="], dims=["i"])
869+
mc._data["rhs"] = xr.DataArray([np.inf, -np.inf, 1.0, 2.0], dims=["i"])
870+
frozen = mc.freeze()
871+
frozen.sanitize_infinities()
872+
assert frozen.ncons == 2
873+
np.testing.assert_array_equal(frozen._rhs, [1.0, 2.0])
874+
875+
876+
def test_mixed_sign_repr() -> None:
877+
m = Model()
878+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
879+
coords = [pd.RangeIndex(4, name="i")]
880+
881+
def bound(m, i):
882+
if i % 2:
883+
return x.at[i] >= i
884+
return x.at[i] == 0.0
885+
886+
con = m.add_constraints(bound, coords=coords, name="c")
887+
r = repr(con)
888+
assert "≥" in r
889+
assert "=" in r

test/test_io.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,31 @@ def test_model_to_netcdf_frozen_constraint(tmp_path: Path) -> None:
9595
assert_model_equal(m, p)
9696

9797

98+
def test_model_to_netcdf_mixed_sign_constraint(tmp_path: Path) -> None:
99+
from linopy.constraints import Constraint
100+
101+
m = Model()
102+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
103+
104+
def bound(m, i):
105+
if i % 2:
106+
return x.at[i] >= i
107+
return x.at[i] == 0.0
108+
109+
m.add_constraints(bound, coords=[pd.RangeIndex(4, name="i")], name="c")
110+
assert isinstance(m.constraints["c"], Constraint)
111+
112+
fn = tmp_path / "test_mixed_sign.nc"
113+
m.to_netcdf(fn)
114+
p = read_netcdf(fn)
115+
116+
assert isinstance(p.constraints["c"], Constraint)
117+
import numpy as np
118+
119+
np.testing.assert_array_equal(m.constraints["c"]._sign, p.constraints["c"]._sign)
120+
assert_model_equal(m, p)
121+
122+
98123
def test_model_to_netcdf_with_sense(model: Model, tmp_path: Path) -> None:
99124
m = model
100125
m.objective.sense = "max"

0 commit comments

Comments
 (0)