Skip to content

Commit 4d65f7f

Browse files
committed
fix: review fixes for PR #630 (matrix accessor rewrite)
- Fix __repr__ passing CSR positions instead of variable labels - Fix set_blocks failing on frozen Constraint - Extract _active_to_dataarray helper to reduce DRY violations - Simplify reset_dual to direct mutation instead of reconstruction - Add tests for freeze/mutable roundtrip, VariableLabelIndex, to_matrix_with_rhs, from_mutable mixed signs, repr correctness
1 parent eca2945 commit 4d65f7f

2 files changed

Lines changed: 104 additions & 43 deletions

File tree

linopy/constraints.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -613,16 +613,19 @@ def nterm(self) -> int:
613613
def coord_names(self) -> list[str]:
614614
return [str(c.name) for c in self._coords]
615615

616+
def _active_to_dataarray(
617+
self, active_values: np.ndarray, fill: float | int = -1
618+
) -> DataArray:
619+
full = np.full(self.full_size, fill, dtype=active_values.dtype)
620+
full[self.active_positions] = active_values
621+
return DataArray(full.reshape(self.shape), coords=self._coords)
622+
616623
@property
617624
def labels(self) -> DataArray:
618625
"""Get labels DataArray, shape (*coord_dims)."""
619626
if self._cindex is None:
620627
return DataArray([])
621-
shape = self.shape
622-
full_size = self.full_size
623-
labels_flat = np.full(full_size, -1, dtype=np.int64)
624-
labels_flat[self.active_positions] = self._con_labels
625-
return DataArray(labels_flat.reshape(shape), coords=self._coords)
628+
return self._active_to_dataarray(self._con_labels, fill=-1)
626629

627630
@property
628631
def coeffs(self) -> DataArray:
@@ -654,10 +657,7 @@ def sign(self) -> DataArray:
654657
@property
655658
def rhs(self) -> DataArray:
656659
"""Get RHS DataArray, shape (*coord_dims)."""
657-
shape = self.shape
658-
rhs_full = np.full(self.full_size, np.nan)
659-
rhs_full[self.active_positions] = self._rhs
660-
return DataArray(rhs_full.reshape(shape), coords=self._coords)
660+
return self._active_to_dataarray(self._rhs, fill=np.nan)
661661

662662
@property
663663
@has_optimized_model
@@ -667,9 +667,7 @@ def dual(self) -> DataArray:
667667
raise AttributeError(
668668
"Underlying is optimized but does not have dual values stored."
669669
)
670-
dual_full = np.full(self.full_size, np.nan)
671-
dual_full[self.active_positions] = self._dual
672-
return DataArray(dual_full.reshape(self.shape), coords=self._coords)
670+
return self._active_to_dataarray(self._dual, fill=np.nan)
673671

674672
@dual.setter
675673
def dual(self, value: DataArray) -> None:
@@ -731,24 +729,10 @@ def _to_dataset(self, nterm: int) -> Dataset:
731729
def data(self) -> Dataset:
732730
"""Reconstruct the xarray Dataset from the CSR representation."""
733731
ds = self._to_dataset(self.nterm)
734-
shape = self.shape
735-
active_pos = self.active_positions
736-
rhs_full = np.full(self.full_size, np.nan)
737-
rhs_full[active_pos] = self._rhs
738-
ds = ds.assign(
739-
sign=DataArray(np.full(shape, self._sign), coords=self._coords),
740-
rhs=DataArray(rhs_full.reshape(shape), coords=self._coords),
741-
)
732+
ds = ds.assign(sign=self.sign, rhs=self.rhs)
742733
if self._dual is not None:
743-
dual_full = np.full(self.full_size, np.nan)
744-
dual_full[active_pos] = self._dual
745-
ds = ds.assign(
746-
dual=DataArray(dual_full.reshape(shape), coords=self._coords)
747-
)
748-
attrs: dict[str, Any] = {"name": self._name}
749-
if self._cindex is not None:
750-
attrs["label_range"] = (self._cindex, self._cindex + self.full_size)
751-
return ds.assign_attrs(attrs)
734+
ds = ds.assign(dual=self._active_to_dataarray(self._dual, fill=np.nan))
735+
return ds.assign_attrs(self.attrs)
752736

753737
def __repr__(self) -> str:
754738
"""Print the constraint without reconstructing the full Dataset."""
@@ -771,11 +755,13 @@ def __repr__(self) -> str:
771755
header_string = f"{self.type} `{self._name}`" if self._name else f"{self.type}"
772756
lines = []
773757

758+
vlabels = self._model.variables.label_index.vlabels
759+
774760
def row_expr(row: int) -> str:
775761
start, end = int(csr.indptr[row]), int(csr.indptr[row + 1])
776762
vars_row = np.full(nterm, -1, dtype=np.int64)
777763
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
778-
vars_row[: end - start] = csr.indices[start:end]
764+
vars_row[: end - start] = vlabels[csr.indices[start:end]]
779765
coeffs_row[: end - start] = csr.data[start:end]
780766
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
781767

@@ -1630,7 +1616,12 @@ def set_blocks(self, block_map: np.ndarray) -> None:
16301616

16311617
res = res.where(not_missing.any(constraint.term_dim), -1)
16321618
res = res.where(not_zero.any(constraint.term_dim), 0)
1633-
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
1619+
if isinstance(constraint, MutableConstraint):
1620+
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
1621+
else:
1622+
mc = constraint.mutable()
1623+
mc._data = assign_multiindex_safe(mc.data, blocks=res)
1624+
self.data[name] = Constraint.from_mutable(mc, constraint._cindex)
16341625

16351626
@property
16361627
def flat(self) -> pd.DataFrame:
@@ -1687,18 +1678,7 @@ def reset_dual(self) -> None:
16871678
"""
16881679
for k, c in self.items():
16891680
if isinstance(c, Constraint):
1690-
if c._dual is not None:
1691-
self.data[k] = Constraint(
1692-
c._csr,
1693-
c._con_labels,
1694-
c._rhs,
1695-
c._sign,
1696-
c._coords,
1697-
c._model,
1698-
c._name,
1699-
cindex=c._cindex,
1700-
dual=None,
1701-
)
1681+
c._dual = None
17021682
elif isinstance(c, MutableConstraint):
17031683
if "dual" in c.data:
17041684
c._data = c.data.drop_vars("dual")

test/test_constraint.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,3 +705,84 @@ def test_constraints_inequalities(m: Model) -> None:
705705

706706
def test_constraints_equalities(m: Model) -> None:
707707
assert isinstance(m.constraints.equalities, Constraints)
708+
709+
710+
def test_freeze_mutable_roundtrip(m: Model) -> None:
711+
frozen = m.constraints["c"]
712+
assert isinstance(frozen, Constraint)
713+
mc = frozen.mutable()
714+
assert isinstance(mc, MutableConstraint)
715+
refrozen = Constraint.from_mutable(mc, frozen._cindex)
716+
assert_equal(frozen.labels, refrozen.labels)
717+
assert_equal(frozen.rhs, refrozen.rhs)
718+
assert_equal(frozen.sign, refrozen.sign)
719+
np.testing.assert_array_equal(frozen._csr.toarray(), refrozen._csr.toarray())
720+
np.testing.assert_array_equal(frozen._con_labels, refrozen._con_labels)
721+
722+
723+
def test_freeze_mutable_roundtrip_with_masking() -> None:
724+
m = Model()
725+
x = m.add_variables(coords=[pd.RangeIndex(5, name="i")], name="x")
726+
mask = xr.DataArray([True, False, True, False, True], dims=["i"])
727+
m.add_constraints(x.where(mask) >= 0, name="c")
728+
frozen = m.constraints["c"]
729+
mc = frozen.mutable()
730+
refrozen = Constraint.from_mutable(mc, frozen._cindex)
731+
assert_equal(frozen.labels, refrozen.labels)
732+
assert_equal(frozen.rhs, refrozen.rhs)
733+
assert frozen.ncons == refrozen.ncons == 3
734+
735+
736+
def test_from_mutable_mixed_signs_raises() -> None:
737+
m = Model()
738+
x = m.add_variables(coords=[pd.RangeIndex(3, name="i")], name="x")
739+
m.add_constraints(x >= 0, name="mixed", freeze=False)
740+
mc = m.constraints["mixed"]
741+
assert isinstance(mc, MutableConstraint)
742+
mc._data["sign"] = xr.DataArray(["<=", ">=", "<="], dims=["i"])
743+
with pytest.raises(ValueError, match="per-element signs"):
744+
Constraint.from_mutable(mc)
745+
746+
747+
def test_variable_label_index(m: Model) -> None:
748+
li = m.variables.label_index
749+
assert li.n_active_vars > 0
750+
assert len(li.vlabels) == li.n_active_vars
751+
assert li.label_to_pos.shape[0] == m._xCounter
752+
for lbl in li.vlabels:
753+
assert li.label_to_pos[lbl] >= 0
754+
assert (li.label_to_pos[li.vlabels] == np.arange(li.n_active_vars)).all()
755+
756+
757+
def test_variable_label_index_invalidation(m: Model) -> None:
758+
li = m.variables.label_index
759+
old_vlabels = li.vlabels.copy()
760+
m.add_variables(name="w")
761+
li.invalidate()
762+
assert len(li.vlabels) > len(old_vlabels)
763+
764+
765+
def test_to_matrix_with_rhs(m: Model) -> None:
766+
c = m.constraints["c"]
767+
li = m.variables.label_index
768+
csr, con_labels, b, sense = c.to_matrix_with_rhs(li)
769+
assert csr.shape[0] == len(con_labels)
770+
assert csr.shape[0] == len(b)
771+
assert csr.shape[0] == len(sense)
772+
assert all(s in ("<", ">", "=") for s in sense)
773+
np.testing.assert_array_equal(b, c._rhs)
774+
775+
776+
def test_to_matrix_with_rhs_mutable(m: Model) -> None:
777+
mc = m.constraints["c"].mutable()
778+
li = m.variables.label_index
779+
csr, con_labels, b, sense = mc.to_matrix_with_rhs(li)
780+
assert csr.shape[0] == len(con_labels)
781+
assert csr.shape[0] == len(b)
782+
assert csr.shape[0] == len(sense)
783+
784+
785+
def test_constraint_repr_shows_variable_names(m: Model) -> None:
786+
c = m.constraints["c"]
787+
r = repr(c)
788+
assert "x" in r

0 commit comments

Comments
 (0)