Skip to content

Commit 3b2a415

Browse files
fix: review fixes for #630 (matrix accessor rewrite) (#632)
* fix: review fixes for PR #630 (matrix accessor rewrite) - Fix __repr__ passing CSR positions instead of variable labels - Fix set_blocks failing on frozen Constraint - Extract _active_to_dataarray helper to reduce DRY violations - Simplify reset_dual to direct mutation instead of reconstruction - Add tests for freeze/mutable roundtrip, VariableLabelIndex, to_matrix_with_rhs, from_mutable mixed signs, repr correctness * feat: support mixed per-element signs in frozen Constraint, add rhs/lhs setters - Store _sign as str (uniform, fast) or np.ndarray (mixed, per-row) - Add rhs/lhs setters on Constraint via _refreeze_after pattern - Invalidate _dual on mutation; update netcdf serialization for array signs - Add tests for setters, mixed-sign freeze/roundtrip/sanitize/repr/netcdf * feat: mixed per-element signs in frozen Constraint, rhs/lhs setters, DRY helpers
1 parent eca2945 commit 3b2a415

3 files changed

Lines changed: 291 additions & 61 deletions

File tree

linopy/constraints.py

Lines changed: 84 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,9 @@ class Constraint(ConstraintBase):
497497
constraint grid (including masked/empty rows).
498498
rhs : np.ndarray
499499
Shape (n_flat,). Right-hand-side values.
500-
sign : str
501-
Constraint sign: one of '=', '<=', '>='.
502-
Note: per-element signs are not supported (documented regression vs MutableConstraint).
500+
sign : str or np.ndarray
501+
Constraint sign. Either a single str ('=', '<=', '>=') for uniform
502+
signs, or a per-row np.ndarray of sign strings for mixed signs.
503503
coords : list of pd.Index
504504
One index per coordinate dimension defining the constraint grid.
505505
model : Model
@@ -529,7 +529,7 @@ def __init__(
529529
csr: scipy.sparse.csr_array,
530530
con_labels: np.ndarray,
531531
rhs: np.ndarray,
532-
sign: str,
532+
sign: str | np.ndarray,
533533
coords: list[pd.Index],
534534
model: Model,
535535
name: str = "",
@@ -613,16 +613,19 @@ def nterm(self) -> int:
613613
def coord_names(self) -> list[str]:
614614
return [str(c.name) for c in self._coords]
615615

616+
def _active_to_dataarray(
617+
self, active_values: np.ndarray, fill: float | int | str = -1
618+
) -> DataArray:
619+
full = np.full(self.full_size, fill, dtype=active_values.dtype)
620+
full[self.active_positions] = active_values
621+
return DataArray(full.reshape(self.shape), coords=self._coords)
622+
616623
@property
617624
def labels(self) -> DataArray:
618625
"""Get labels DataArray, shape (*coord_dims)."""
619626
if self._cindex is None:
620627
return DataArray([])
621-
shape = self.shape
622-
full_size = self.full_size
623-
labels_flat = np.full(full_size, -1, dtype=np.int64)
624-
labels_flat[self.active_positions] = self._con_labels
625-
return DataArray(labels_flat.reshape(shape), coords=self._coords)
628+
return self._active_to_dataarray(self._con_labels, fill=-1)
626629

627630
@property
628631
def coeffs(self) -> DataArray:
@@ -648,16 +651,39 @@ def vars(self) -> DataArray:
648651

649652
@property
650653
def sign(self) -> DataArray:
651-
"""Get sign DataArray (scalar, same sign for all entries)."""
652-
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
654+
"""Get sign DataArray."""
655+
if isinstance(self._sign, str):
656+
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
657+
return self._active_to_dataarray(self._sign, fill="")
653658

654659
@property
655660
def rhs(self) -> DataArray:
656661
"""Get RHS DataArray, shape (*coord_dims)."""
657-
shape = self.shape
658-
rhs_full = np.full(self.full_size, np.nan)
659-
rhs_full[self.active_positions] = self._rhs
660-
return DataArray(rhs_full.reshape(shape), coords=self._coords)
662+
return self._active_to_dataarray(self._rhs, fill=np.nan)
663+
664+
@rhs.setter
665+
def rhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
666+
self._refreeze_after(lambda mc: setattr(mc, "rhs", value))
667+
668+
@property
669+
def lhs(self) -> expressions.LinearExpression:
670+
"""Get LHS as LinearExpression (triggers Dataset reconstruction)."""
671+
return self.mutable().lhs
672+
673+
@lhs.setter
674+
def lhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
675+
self._refreeze_after(lambda mc: setattr(mc, "lhs", value))
676+
677+
def _refreeze_after(self, mutate: Callable[[MutableConstraint], None]) -> None:
678+
mc = self.mutable()
679+
mutate(mc)
680+
refrozen = Constraint.from_mutable(mc, self._cindex)
681+
self._csr = refrozen._csr
682+
self._con_labels = refrozen._con_labels
683+
self._rhs = refrozen._rhs
684+
self._sign = refrozen._sign
685+
self._coords = refrozen._coords
686+
self._dual = None
661687

662688
@property
663689
@has_optimized_model
@@ -667,9 +693,7 @@ def dual(self) -> DataArray:
667693
raise AttributeError(
668694
"Underlying is optimized but does not have dual values stored."
669695
)
670-
dual_full = np.full(self.full_size, np.nan)
671-
dual_full[self.active_positions] = self._dual
672-
return DataArray(dual_full.reshape(self.shape), coords=self._coords)
696+
return self._active_to_dataarray(self._dual, fill=np.nan)
673697

674698
@dual.setter
675699
def dual(self, value: DataArray) -> None:
@@ -731,24 +755,10 @@ def _to_dataset(self, nterm: int) -> Dataset:
731755
def data(self) -> Dataset:
732756
"""Reconstruct the xarray Dataset from the CSR representation."""
733757
ds = self._to_dataset(self.nterm)
734-
shape = self.shape
735-
active_pos = self.active_positions
736-
rhs_full = np.full(self.full_size, np.nan)
737-
rhs_full[active_pos] = self._rhs
738-
ds = ds.assign(
739-
sign=DataArray(np.full(shape, self._sign), coords=self._coords),
740-
rhs=DataArray(rhs_full.reshape(shape), coords=self._coords),
741-
)
758+
ds = ds.assign(sign=self.sign, rhs=self.rhs)
742759
if self._dual is not None:
743-
dual_full = np.full(self.full_size, np.nan)
744-
dual_full[active_pos] = self._dual
745-
ds = ds.assign(
746-
dual=DataArray(dual_full.reshape(shape), coords=self._coords)
747-
)
748-
attrs: dict[str, Any] = {"name": self._name}
749-
if self._cindex is not None:
750-
attrs["label_range"] = (self._cindex, self._cindex + self.full_size)
751-
return ds.assign_attrs(attrs)
760+
ds = ds.assign(dual=self._active_to_dataarray(self._dual, fill=np.nan))
761+
return ds.assign_attrs(self.attrs)
752762

753763
def __repr__(self) -> str:
754764
"""Print the constraint without reconstructing the full Dataset."""
@@ -777,7 +787,8 @@ def row_expr(row: int) -> str:
777787
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
778788
vars_row[: end - start] = csr.indices[start:end]
779789
coeffs_row[: end - start] = csr.data[start:end]
780-
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
790+
sign = self._sign if isinstance(self._sign, str) else self._sign[row]
791+
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[sign]} {self._rhs[row]}"
781792

782793
if size > 1:
783794
for indices in generate_indices_for_printout(shape, max_lines):
@@ -819,21 +830,22 @@ def to_netcdf_ds(self) -> Dataset:
819830
"rhs": DataArray(self._rhs, dims=["_flat"]),
820831
"_con_labels": DataArray(self._con_labels, dims=["_flat"]),
821832
}
833+
if isinstance(self._sign, np.ndarray):
834+
data_vars["_sign"] = DataArray(self._sign, dims=["_flat"])
822835
data_vars.update(coords_to_dataset_vars(self._coords))
823836
if self._dual is not None:
824837
data_vars["dual"] = DataArray(self._dual, dims=["_flat"])
825838
dim_names = [c.name for c in self._coords]
826-
return Dataset(
827-
data_vars,
828-
attrs={
829-
"_linopy_format": "csr",
830-
"sign": self._sign,
831-
"cindex": self._cindex if self._cindex is not None else -1,
832-
"shape": list(csr.shape),
833-
"coord_dims": dim_names,
834-
"name": self._name,
835-
},
836-
)
839+
attrs: dict[str, Any] = {
840+
"_linopy_format": "csr",
841+
"cindex": self._cindex if self._cindex is not None else -1,
842+
"shape": list(csr.shape),
843+
"coord_dims": dim_names,
844+
"name": self._name,
845+
}
846+
if isinstance(self._sign, str):
847+
attrs["sign"] = self._sign
848+
return Dataset(data_vars, attrs=attrs)
837849

838850
@classmethod
839851
def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
@@ -845,7 +857,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
845857
shape=shape,
846858
)
847859
rhs = ds["rhs"].values
848-
sign = attrs["sign"]
860+
sign: str | np.ndarray = ds["_sign"].values if "_sign" in ds else attrs["sign"]
849861
_cindex_raw = int(attrs["cindex"])
850862
cindex: int | None = _cindex_raw if _cindex_raw >= 0 else None
851863
coord_dims = attrs["coord_dims"]
@@ -873,7 +885,10 @@ def to_matrix_with_rhs(
873885
self, label_index: VariableLabelIndex
874886
) -> tuple[scipy.sparse.csr_array, np.ndarray, np.ndarray, np.ndarray]:
875887
"""Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
876-
sense = np.full(len(self._rhs), self._sign[0])
888+
if isinstance(self._sign, str):
889+
sense = np.full(len(self._rhs), self._sign[0])
890+
else:
891+
sense = np.array([s[0] for s in self._sign])
877892
return self._csr, self._con_labels, self._rhs, sense
878893

879894
def sanitize_zeros(self) -> Constraint:
@@ -888,18 +903,25 @@ def sanitize_missings(self) -> Constraint:
888903

889904
def sanitize_infinities(self) -> Constraint:
890905
"""Mask out rows with invalid infinite RHS values (mutates in-place)."""
891-
if self._sign == LESS_EQUAL:
892-
invalid = self._rhs == np.inf
893-
elif self._sign == GREATER_EQUAL:
894-
invalid = self._rhs == -np.inf
906+
if isinstance(self._sign, str):
907+
if self._sign == LESS_EQUAL:
908+
invalid = self._rhs == np.inf
909+
elif self._sign == GREATER_EQUAL:
910+
invalid = self._rhs == -np.inf
911+
else:
912+
return self
895913
else:
896-
return self
914+
invalid = ((self._sign == LESS_EQUAL) & (self._rhs == np.inf)) | (
915+
(self._sign == GREATER_EQUAL) & (self._rhs == -np.inf)
916+
)
897917
if not invalid.any():
898918
return self
899919
keep = ~invalid
900920
self._csr = self._csr[keep]
901921
self._con_labels = self._con_labels[keep]
902922
self._rhs = self._rhs[keep]
923+
if not isinstance(self._sign, str):
924+
self._sign = self._sign[keep]
903925
return self
904926

905927
def freeze(self) -> Constraint:
@@ -939,13 +961,14 @@ def from_mutable(
939961
active_mask = (labels_flat != -1) & (vars_flat != -1).any(axis=1)
940962
rhs = con.rhs.values.ravel()[active_mask]
941963
sign_vals = con.sign.values.ravel()
942-
unique_signs = np.unique(sign_vals[active_mask])
943-
if len(unique_signs) > 1:
944-
raise ValueError(
945-
"Constraint has per-element signs; cannot freeze to immutable Constraint. "
946-
"This is a known limitation — use MutableConstraint instead."
947-
)
948-
sign = str(unique_signs[0]) if len(unique_signs) == 1 else "="
964+
active_signs = sign_vals[active_mask]
965+
unique_signs = np.unique(active_signs)
966+
if len(unique_signs) == 0:
967+
sign: str | np.ndarray = "="
968+
elif len(unique_signs) == 1:
969+
sign = str(unique_signs[0])
970+
else:
971+
sign = active_signs
949972
dual = (
950973
con.data["dual"].values.ravel()[active_mask] if "dual" in con.data else None
951974
)

0 commit comments

Comments
 (0)