Skip to content

Commit c9b2abf

Browse files
committed
feat: support mixed per-element signs in frozen Constraint, add rhs/lhs setters
- Store _sign as str (uniform, fast) or np.ndarray (mixed, per-row) - Add rhs/lhs setters on Constraint via _refreeze_after pattern - Invalidate _dual on mutation; update netcdf serialization for array signs - Add tests for setters, mixed-sign freeze/roundtrip/sanitize/repr/netcdf
1 parent 01add96 commit c9b2abf

3 files changed

Lines changed: 198 additions & 33 deletions

File tree

linopy/constraints.py

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def __init__(
511511
csr: scipy.sparse.csr_array,
512512
con_labels: np.ndarray,
513513
rhs: np.ndarray,
514-
sign: str,
514+
sign: str | np.ndarray,
515515
coords: list[pd.Index],
516516
model: Model,
517517
name: str = "",
@@ -596,7 +596,7 @@ def coord_names(self) -> list[str]:
596596
return [c.name for c in self._coords]
597597

598598
def _active_to_dataarray(
599-
self, active_values: np.ndarray, fill: float | int = -1
599+
self, active_values: np.ndarray, fill: float | int | str = -1
600600
) -> DataArray:
601601
full = np.full(self.full_size, fill, dtype=active_values.dtype)
602602
full[self.active_positions] = active_values
@@ -633,14 +633,40 @@ def vars(self) -> DataArray:
633633

634634
@property
635635
def sign(self) -> DataArray:
636-
"""Get sign DataArray (scalar, same sign for all entries)."""
637-
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
636+
"""Get sign DataArray."""
637+
if isinstance(self._sign, str):
638+
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
639+
return self._active_to_dataarray(self._sign, fill="")
638640

639641
@property
640642
def rhs(self) -> DataArray:
641643
"""Get RHS DataArray, shape (*coord_dims)."""
642644
return self._active_to_dataarray(self._rhs, fill=np.nan)
643645

646+
@rhs.setter
647+
def rhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
648+
self._refreeze_after(lambda mc: setattr(mc, "rhs", value))
649+
650+
@property
651+
def lhs(self) -> expressions.LinearExpression:
652+
"""Get LHS as LinearExpression (triggers Dataset reconstruction)."""
653+
return self.mutable().lhs
654+
655+
@lhs.setter
656+
def lhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
657+
self._refreeze_after(lambda mc: setattr(mc, "lhs", value))
658+
659+
def _refreeze_after(self, mutate: Callable[[MutableConstraint], None]) -> None:
660+
mc = self.mutable()
661+
mutate(mc)
662+
refrozen = Constraint.from_mutable(mc, self._cindex)
663+
self._csr = refrozen._csr
664+
self._con_labels = refrozen._con_labels
665+
self._rhs = refrozen._rhs
666+
self._sign = refrozen._sign
667+
self._coords = refrozen._coords
668+
self._dual = None
669+
644670
@property
645671
@has_optimized_model
646672
def dual(self) -> DataArray:
@@ -745,7 +771,8 @@ def row_expr(row: int) -> str:
745771
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
746772
vars_row[: end - start] = vlabels[csr.indices[start:end]]
747773
coeffs_row[: end - start] = csr.data[start:end]
748-
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
774+
sign = self._sign if isinstance(self._sign, str) else self._sign[row]
775+
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[sign]} {self._rhs[row]}"
749776

750777
if size > 1:
751778
for indices in generate_indices_for_printout(shape, max_lines):
@@ -787,21 +814,22 @@ def to_netcdf_ds(self) -> Dataset:
787814
"rhs": DataArray(self._rhs, dims=["_flat"]),
788815
"_con_labels": DataArray(self._con_labels, dims=["_flat"]),
789816
}
817+
if isinstance(self._sign, np.ndarray):
818+
data_vars["_sign"] = DataArray(self._sign, dims=["_flat"])
790819
data_vars.update(coords_to_dataset_vars(self._coords))
791820
if self._dual is not None:
792821
data_vars["dual"] = DataArray(self._dual, dims=["_flat"])
793822
dim_names = [c.name for c in self._coords]
794-
return Dataset(
795-
data_vars,
796-
attrs={
797-
"_linopy_format": "csr",
798-
"sign": self._sign,
799-
"cindex": self._cindex if self._cindex is not None else -1,
800-
"shape": list(csr.shape),
801-
"coord_dims": dim_names,
802-
"name": self._name,
803-
},
804-
)
823+
attrs: dict[str, Any] = {
824+
"_linopy_format": "csr",
825+
"cindex": self._cindex if self._cindex is not None else -1,
826+
"shape": list(csr.shape),
827+
"coord_dims": dim_names,
828+
"name": self._name,
829+
}
830+
if isinstance(self._sign, str):
831+
attrs["sign"] = self._sign
832+
return Dataset(data_vars, attrs=attrs)
805833

806834
@classmethod
807835
def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
@@ -813,7 +841,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
813841
shape=shape,
814842
)
815843
rhs = ds["rhs"].values
816-
sign = attrs["sign"]
844+
sign: str | np.ndarray = ds["_sign"].values if "_sign" in ds else attrs["sign"]
817845
cindex = int(attrs["cindex"])
818846
cindex = cindex if cindex >= 0 else None
819847
coord_dims = attrs["coord_dims"]
@@ -841,7 +869,10 @@ def to_matrix_with_rhs(
841869
self, label_index: VariableLabelIndex
842870
) -> tuple[scipy.sparse.csr_array, np.ndarray, np.ndarray, np.ndarray]:
843871
"""Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
844-
sense = np.full(len(self._rhs), self._sign[0])
872+
if isinstance(self._sign, str):
873+
sense = np.full(len(self._rhs), self._sign[0])
874+
else:
875+
sense = np.array([s[0] for s in self._sign])
845876
return self._csr, self._con_labels, self._rhs, sense
846877

847878
def sanitize_zeros(self) -> Constraint:
@@ -856,18 +887,25 @@ def sanitize_missings(self) -> Constraint:
856887

857888
def sanitize_infinities(self) -> Constraint:
858889
"""Mask out rows with invalid infinite RHS values (mutates in-place)."""
859-
if self._sign == LESS_EQUAL:
860-
invalid = self._rhs == np.inf
861-
elif self._sign == GREATER_EQUAL:
862-
invalid = self._rhs == -np.inf
890+
if isinstance(self._sign, str):
891+
if self._sign == LESS_EQUAL:
892+
invalid = self._rhs == np.inf
893+
elif self._sign == GREATER_EQUAL:
894+
invalid = self._rhs == -np.inf
895+
else:
896+
return self
863897
else:
864-
return self
898+
invalid = ((self._sign == LESS_EQUAL) & (self._rhs == np.inf)) | (
899+
(self._sign == GREATER_EQUAL) & (self._rhs == -np.inf)
900+
)
865901
if not invalid.any():
866902
return self
867903
keep = ~invalid
868904
self._csr = self._csr[keep]
869905
self._con_labels = self._con_labels[keep]
870906
self._rhs = self._rhs[keep]
907+
if not isinstance(self._sign, str):
908+
self._sign = self._sign[keep]
871909
return self
872910

873911
def freeze(self) -> Constraint:
@@ -907,13 +945,14 @@ def from_mutable(
907945
active_mask = (labels_flat != -1) & (vars_flat != -1).any(axis=1)
908946
rhs = con.rhs.values.ravel()[active_mask]
909947
sign_vals = con.sign.values.ravel()
910-
unique_signs = np.unique(sign_vals[active_mask])
911-
if len(unique_signs) > 1:
912-
raise ValueError(
913-
"Constraint has per-element signs; cannot freeze to immutable Constraint. "
914-
"This is a known limitation — use MutableConstraint instead."
915-
)
916-
sign = str(unique_signs[0]) if len(unique_signs) == 1 else "="
948+
active_signs = sign_vals[active_mask]
949+
unique_signs = np.unique(active_signs)
950+
if len(unique_signs) == 0:
951+
sign: str | np.ndarray = "="
952+
elif len(unique_signs) == 1:
953+
sign = str(unique_signs[0])
954+
else:
955+
sign = active_signs
917956
dual = (
918957
con.data["dual"].values.ravel()[active_mask] if "dual" in con.data else None
919958
)

test/test_constraint.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,15 +734,17 @@ def test_freeze_mutable_roundtrip_with_masking() -> None:
734734
assert frozen.ncons == refrozen.ncons == 3
735735

736736

737-
def test_from_mutable_mixed_signs_raises() -> None:
737+
def test_from_mutable_mixed_signs() -> None:
738738
m = Model()
739739
x = m.add_variables(coords=[pd.RangeIndex(3, name="i")], name="x")
740740
m.add_constraints(x >= 0, name="mixed", freeze=False)
741741
mc = m.constraints["mixed"]
742742
assert isinstance(mc, MutableConstraint)
743743
mc._data["sign"] = xr.DataArray(["<=", ">=", "<="], dims=["i"])
744-
with pytest.raises(ValueError, match="per-element signs"):
745-
Constraint.from_mutable(mc)
744+
frozen = Constraint.from_mutable(mc)
745+
assert isinstance(frozen._sign, np.ndarray)
746+
assert list(frozen._sign) == ["<=", ">=", "<="]
747+
assert_equal(frozen.sign, mc.sign)
746748

747749

748750
def test_variable_label_index(m: Model) -> None:
@@ -787,3 +789,102 @@ def test_constraint_repr_shows_variable_names(m: Model) -> None:
787789
c = m.constraints["c"]
788790
r = repr(c)
789791
assert "x" in r
792+
793+
794+
def test_freeze_mixed_signs_from_rule() -> None:
795+
m = Model()
796+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
797+
coords = [pd.RangeIndex(4, name="i")]
798+
799+
def bound(m, i):
800+
if i % 2:
801+
return x.at[i] >= i
802+
return x.at[i] == 0.0
803+
804+
con = m.add_constraints(bound, coords=coords, name="mixed_rule")
805+
assert isinstance(con, Constraint)
806+
assert isinstance(con._sign, np.ndarray)
807+
assert con.ncons == 4
808+
expected_signs = ["=", ">=", "=", ">="]
809+
assert list(con._sign) == expected_signs
810+
np.testing.assert_array_equal(con.sign.values, expected_signs)
811+
812+
813+
def test_frozen_rhs_setter() -> None:
814+
m = Model()
815+
time = pd.RangeIndex(5, name="t")
816+
x = m.add_variables(lower=0, coords=[time], name="x")
817+
con = m.add_constraints(x >= 1, name="c")
818+
assert isinstance(con, Constraint)
819+
con.rhs = 10
820+
np.testing.assert_array_equal(con._rhs, np.full(5, 10.0))
821+
factor = pd.Series(range(5), index=time)
822+
con.rhs = 2 * factor
823+
np.testing.assert_array_equal(con._rhs, 2 * np.arange(5, dtype=float))
824+
825+
826+
def test_frozen_lhs_setter() -> None:
827+
m = Model()
828+
time = pd.RangeIndex(5, name="t")
829+
x = m.add_variables(lower=0, coords=[time], name="x")
830+
y = m.add_variables(lower=0, coords=[time], name="y")
831+
con = m.add_constraints(x >= 0, name="c")
832+
assert isinstance(con, Constraint)
833+
con.lhs = 3 * x + 2 * y
834+
lhs = con.mutable().lhs
835+
assert lhs.nterm == 2
836+
837+
838+
def test_frozen_setter_invalidates_dual() -> None:
839+
m = Model()
840+
x = m.add_variables(lower=0, coords=[pd.RangeIndex(3, name="i")], name="x")
841+
con = m.add_constraints(x >= 0, name="c")
842+
con._dual = np.array([1.0, 2.0, 3.0])
843+
con.rhs = 10
844+
assert con._dual is None
845+
846+
847+
def test_mixed_sign_to_matrix_with_rhs() -> None:
848+
m = Model()
849+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
850+
coords = [pd.RangeIndex(4, name="i")]
851+
852+
def bound(m, i):
853+
if i % 2:
854+
return x.at[i] >= i
855+
return x.at[i] == 0.0
856+
857+
con = m.add_constraints(bound, coords=coords, name="c")
858+
li = m.variables.label_index
859+
csr, con_labels, b, sense = con.to_matrix_with_rhs(li)
860+
assert len(sense) == 4
861+
assert list(sense) == ["=", ">", "=", ">"]
862+
863+
864+
def test_mixed_sign_sanitize_infinities() -> None:
865+
m = Model()
866+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
867+
m.add_constraints(x >= 0, name="c", freeze=False)
868+
mc = m.constraints["c"]
869+
mc._data["sign"] = xr.DataArray(["<=", ">=", "<=", ">="], dims=["i"])
870+
mc._data["rhs"] = xr.DataArray([np.inf, -np.inf, 1.0, 2.0], dims=["i"])
871+
frozen = mc.freeze()
872+
frozen.sanitize_infinities()
873+
assert frozen.ncons == 2
874+
np.testing.assert_array_equal(frozen._rhs, [1.0, 2.0])
875+
876+
877+
def test_mixed_sign_repr() -> None:
878+
m = Model()
879+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
880+
coords = [pd.RangeIndex(4, name="i")]
881+
882+
def bound(m, i):
883+
if i % 2:
884+
return x.at[i] >= i
885+
return x.at[i] == 0.0
886+
887+
con = m.add_constraints(bound, coords=coords, name="c")
888+
r = repr(con)
889+
assert "≥" in r
890+
assert "=" in r

test/test_io.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,31 @@ def test_model_to_netcdf_frozen_constraint(tmp_path: Path) -> None:
9595
assert_model_equal(m, p)
9696

9797

98+
def test_model_to_netcdf_mixed_sign_constraint(tmp_path: Path) -> None:
99+
from linopy.constraints import Constraint
100+
101+
m = Model()
102+
x = m.add_variables(coords=[pd.RangeIndex(4, name="i")], name="x")
103+
104+
def bound(m, i):
105+
if i % 2:
106+
return x.at[i] >= i
107+
return x.at[i] == 0.0
108+
109+
m.add_constraints(bound, coords=[pd.RangeIndex(4, name="i")], name="c")
110+
assert isinstance(m.constraints["c"], Constraint)
111+
112+
fn = tmp_path / "test_mixed_sign.nc"
113+
m.to_netcdf(fn)
114+
p = read_netcdf(fn)
115+
116+
assert isinstance(p.constraints["c"], Constraint)
117+
import numpy as np
118+
119+
np.testing.assert_array_equal(m.constraints["c"]._sign, p.constraints["c"]._sign)
120+
assert_model_equal(m, p)
121+
122+
98123
def test_model_to_netcdf_with_sense(model: Model, tmp_path: Path) -> None:
99124
m = model
100125
m.objective.sense = "max"

0 commit comments

Comments
 (0)