Skip to content

Commit c8a4ddc

Browse files
committed
feat: mixed per-element signs in frozen Constraint, rhs/lhs setters, DRY helpers
1 parent 351ef71 commit c8a4ddc

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

linopy/constraints.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,9 @@ class Constraint(ConstraintBase):
497497
constraint grid (including masked/empty rows).
498498
rhs : np.ndarray
499499
Shape (n_flat,). Right-hand-side values.
500-
sign : str
501-
Constraint sign: one of '=', '<=', '>='.
502-
Note: per-element signs are not supported (documented regression vs MutableConstraint).
500+
sign : str or np.ndarray
501+
Constraint sign. Either a single str ('=', '<=', '>=') for uniform
502+
signs, or a per-row np.ndarray of sign strings for mixed signs.
503503
coords : list of pd.Index
504504
One index per coordinate dimension defining the constraint grid.
505505
model : Model
@@ -781,13 +781,11 @@ def __repr__(self) -> str:
781781
header_string = f"{self.type} `{self._name}`" if self._name else f"{self.type}"
782782
lines = []
783783

784-
vlabels = self._model.variables.label_index.vlabels
785-
786784
def row_expr(row: int) -> str:
787785
start, end = int(csr.indptr[row]), int(csr.indptr[row + 1])
788786
vars_row = np.full(nterm, -1, dtype=np.int64)
789787
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
790-
vars_row[: end - start] = vlabels[csr.indices[start:end]]
788+
vars_row[: end - start] = csr.indices[start:end]
791789
coeffs_row[: end - start] = csr.data[start:end]
792790
sign = self._sign if isinstance(self._sign, str) else self._sign[row]
793791
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[sign]} {self._rhs[row]}"
@@ -859,9 +857,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
859857
shape=shape,
860858
)
861859
rhs = ds["rhs"].values
862-
sign: str | np.ndarray = (
863-
ds["_sign"].values if "_sign" in ds else attrs["sign"]
864-
)
860+
sign: str | np.ndarray = ds["_sign"].values if "_sign" in ds else attrs["sign"]
865861
_cindex_raw = int(attrs["cindex"])
866862
cindex: int | None = _cindex_raw if _cindex_raw >= 0 else None
867863
coord_dims = attrs["coord_dims"]
@@ -1657,12 +1653,7 @@ def set_blocks(self, block_map: np.ndarray) -> None:
16571653

16581654
res = res.where(not_missing.any(constraint.term_dim), -1)
16591655
res = res.where(not_zero.any(constraint.term_dim), 0)
1660-
if isinstance(constraint, MutableConstraint):
1661-
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
1662-
else:
1663-
mc = constraint.mutable()
1664-
mc._data = assign_multiindex_safe(mc.data, blocks=res)
1665-
self.data[name] = Constraint.from_mutable(mc, constraint._cindex)
1656+
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
16661657

16671658
@property
16681659
def flat(self) -> pd.DataFrame:
@@ -1719,7 +1710,18 @@ def reset_dual(self) -> None:
17191710
"""
17201711
for k, c in self.items():
17211712
if isinstance(c, Constraint):
1722-
c._dual = None
1713+
if c._dual is not None:
1714+
self.data[k] = Constraint(
1715+
c._csr,
1716+
c._con_labels,
1717+
c._rhs,
1718+
c._sign,
1719+
c._coords,
1720+
c._model,
1721+
c._name,
1722+
cindex=c._cindex,
1723+
dual=None,
1724+
)
17231725
elif isinstance(c, MutableConstraint):
17241726
if "dual" in c.data:
17251727
c._data = c.data.drop_vars("dual")

test/test_constraint.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -709,10 +709,10 @@ def test_constraints_equalities(m: Model) -> None:
709709

710710
def test_freeze_mutable_roundtrip(m: Model) -> None:
711711
frozen = m.constraints["c"]
712-
assert isinstance(frozen, Constraint)
712+
assert isinstance(frozen, linopy.constraints.Constraint)
713713
mc = frozen.mutable()
714714
assert isinstance(mc, MutableConstraint)
715-
refrozen = Constraint.from_mutable(mc, frozen._cindex)
715+
refrozen = linopy.constraints.Constraint.from_mutable(mc, frozen._cindex)
716716
assert_equal(frozen.labels, refrozen.labels)
717717
assert_equal(frozen.rhs, refrozen.rhs)
718718
assert_equal(frozen.sign, refrozen.sign)
@@ -727,7 +727,7 @@ def test_freeze_mutable_roundtrip_with_masking() -> None:
727727
m.add_constraints(x.where(mask) >= 0, name="c")
728728
frozen = m.constraints["c"]
729729
mc = frozen.mutable()
730-
refrozen = Constraint.from_mutable(mc, frozen._cindex)
730+
refrozen = linopy.constraints.Constraint.from_mutable(mc, frozen._cindex)
731731
assert_equal(frozen.labels, refrozen.labels)
732732
assert_equal(frozen.rhs, refrozen.rhs)
733733
assert frozen.ncons == refrozen.ncons == 3
@@ -740,7 +740,7 @@ def test_from_mutable_mixed_signs() -> None:
740740
mc = m.constraints["mixed"]
741741
assert isinstance(mc, MutableConstraint)
742742
mc._data["sign"] = xr.DataArray(["<=", ">=", "<="], dims=["i"])
743-
frozen = Constraint.from_mutable(mc)
743+
frozen = linopy.constraints.Constraint.from_mutable(mc)
744744
assert isinstance(frozen._sign, np.ndarray)
745745
assert list(frozen._sign) == ["<=", ">=", "<="]
746746
assert_equal(frozen.sign, mc.sign)
@@ -801,7 +801,7 @@ def bound(m, i):
801801
return x.at[i] == 0.0
802802

803803
con = m.add_constraints(bound, coords=coords, name="mixed_rule")
804-
assert isinstance(con, Constraint)
804+
assert isinstance(con, linopy.constraints.Constraint)
805805
assert isinstance(con._sign, np.ndarray)
806806
assert con.ncons == 4
807807
expected_signs = ["=", ">=", "=", ">="]
@@ -814,7 +814,7 @@ def test_frozen_rhs_setter() -> None:
814814
time = pd.RangeIndex(5, name="t")
815815
x = m.add_variables(lower=0, coords=[time], name="x")
816816
con = m.add_constraints(x >= 1, name="c")
817-
assert isinstance(con, Constraint)
817+
assert isinstance(con, linopy.constraints.Constraint)
818818
con.rhs = 10
819819
np.testing.assert_array_equal(con._rhs, np.full(5, 10.0))
820820
factor = pd.Series(range(5), index=time)
@@ -828,7 +828,7 @@ def test_frozen_lhs_setter() -> None:
828828
x = m.add_variables(lower=0, coords=[time], name="x")
829829
y = m.add_variables(lower=0, coords=[time], name="y")
830830
con = m.add_constraints(x >= 0, name="c")
831-
assert isinstance(con, Constraint)
831+
assert isinstance(con, linopy.constraints.Constraint)
832832
con.lhs = 3 * x + 2 * y
833833
lhs = con.mutable().lhs
834834
assert lhs.nterm == 2

0 commit comments

Comments
 (0)