@@ -497,9 +497,9 @@ class Constraint(ConstraintBase):
497497 constraint grid (including masked/empty rows).
498498 rhs : np.ndarray
499499 Shape (n_flat,). Right-hand-side values.
500- sign : str
501- Constraint sign: one of '=', '<=', '>='.
502- Note: per-element signs are not supported (documented regression vs MutableConstraint) .
500+ sign : str or np.ndarray
501+ Constraint sign. Either a single str ( '=', '<=', '>=') for uniform
502+ signs, or a per-row np.ndarray of sign strings for mixed signs .
503503 coords : list of pd.Index
504504 One index per coordinate dimension defining the constraint grid.
505505 model : Model
@@ -529,7 +529,7 @@ def __init__(
529529 csr : scipy .sparse .csr_array ,
530530 con_labels : np .ndarray ,
531531 rhs : np .ndarray ,
532- sign : str ,
532+ sign : str | np . ndarray ,
533533 coords : list [pd .Index ],
534534 model : Model ,
535535 name : str = "" ,
@@ -613,16 +613,19 @@ def nterm(self) -> int:
613613 def coord_names (self ) -> list [str ]:
614614 return [str (c .name ) for c in self ._coords ]
615615
616+ def _active_to_dataarray (
617+ self , active_values : np .ndarray , fill : float | int | str = - 1
618+ ) -> DataArray :
619+ full = np .full (self .full_size , fill , dtype = active_values .dtype )
620+ full [self .active_positions ] = active_values
621+ return DataArray (full .reshape (self .shape ), coords = self ._coords )
622+
616623 @property
617624 def labels (self ) -> DataArray :
618625 """Get labels DataArray, shape (*coord_dims)."""
619626 if self ._cindex is None :
620627 return DataArray ([])
621- shape = self .shape
622- full_size = self .full_size
623- labels_flat = np .full (full_size , - 1 , dtype = np .int64 )
624- labels_flat [self .active_positions ] = self ._con_labels
625- return DataArray (labels_flat .reshape (shape ), coords = self ._coords )
628+ return self ._active_to_dataarray (self ._con_labels , fill = - 1 )
626629
627630 @property
628631 def coeffs (self ) -> DataArray :
@@ -648,16 +651,39 @@ def vars(self) -> DataArray:
648651
649652 @property
650653 def sign (self ) -> DataArray :
651- """Get sign DataArray (scalar, same sign for all entries)."""
652- return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
654+ """Get sign DataArray."""
655+ if isinstance (self ._sign , str ):
656+ return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
657+ return self ._active_to_dataarray (self ._sign , fill = "" )
653658
654659 @property
655660 def rhs (self ) -> DataArray :
656661 """Get RHS DataArray, shape (*coord_dims)."""
657- shape = self .shape
658- rhs_full = np .full (self .full_size , np .nan )
659- rhs_full [self .active_positions ] = self ._rhs
660- return DataArray (rhs_full .reshape (shape ), coords = self ._coords )
662+ return self ._active_to_dataarray (self ._rhs , fill = np .nan )
663+
664+ @rhs .setter
665+ def rhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
666+ self ._refreeze_after (lambda mc : setattr (mc , "rhs" , value ))
667+
668+ @property
669+ def lhs (self ) -> expressions .LinearExpression :
670+ """Get LHS as LinearExpression (triggers Dataset reconstruction)."""
671+ return self .mutable ().lhs
672+
673+ @lhs .setter
674+ def lhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
675+ self ._refreeze_after (lambda mc : setattr (mc , "lhs" , value ))
676+
677+ def _refreeze_after (self , mutate : Callable [[MutableConstraint ], None ]) -> None :
678+ mc = self .mutable ()
679+ mutate (mc )
680+ refrozen = Constraint .from_mutable (mc , self ._cindex )
681+ self ._csr = refrozen ._csr
682+ self ._con_labels = refrozen ._con_labels
683+ self ._rhs = refrozen ._rhs
684+ self ._sign = refrozen ._sign
685+ self ._coords = refrozen ._coords
686+ self ._dual = None
661687
662688 @property
663689 @has_optimized_model
@@ -667,9 +693,7 @@ def dual(self) -> DataArray:
667693 raise AttributeError (
668694 "Underlying is optimized but does not have dual values stored."
669695 )
670- dual_full = np .full (self .full_size , np .nan )
671- dual_full [self .active_positions ] = self ._dual
672- return DataArray (dual_full .reshape (self .shape ), coords = self ._coords )
696+ return self ._active_to_dataarray (self ._dual , fill = np .nan )
673697
674698 @dual .setter
675699 def dual (self , value : DataArray ) -> None :
@@ -731,24 +755,10 @@ def _to_dataset(self, nterm: int) -> Dataset:
731755 def data (self ) -> Dataset :
732756 """Reconstruct the xarray Dataset from the CSR representation."""
733757 ds = self ._to_dataset (self .nterm )
734- shape = self .shape
735- active_pos = self .active_positions
736- rhs_full = np .full (self .full_size , np .nan )
737- rhs_full [active_pos ] = self ._rhs
738- ds = ds .assign (
739- sign = DataArray (np .full (shape , self ._sign ), coords = self ._coords ),
740- rhs = DataArray (rhs_full .reshape (shape ), coords = self ._coords ),
741- )
758+ ds = ds .assign (sign = self .sign , rhs = self .rhs )
742759 if self ._dual is not None :
743- dual_full = np .full (self .full_size , np .nan )
744- dual_full [active_pos ] = self ._dual
745- ds = ds .assign (
746- dual = DataArray (dual_full .reshape (shape ), coords = self ._coords )
747- )
748- attrs : dict [str , Any ] = {"name" : self ._name }
749- if self ._cindex is not None :
750- attrs ["label_range" ] = (self ._cindex , self ._cindex + self .full_size )
751- return ds .assign_attrs (attrs )
760+ ds = ds .assign (dual = self ._active_to_dataarray (self ._dual , fill = np .nan ))
761+ return ds .assign_attrs (self .attrs )
752762
753763 def __repr__ (self ) -> str :
754764 """Print the constraint without reconstructing the full Dataset."""
@@ -777,7 +787,8 @@ def row_expr(row: int) -> str:
777787 coeffs_row = np .zeros (nterm , dtype = csr .dtype )
778788 vars_row [: end - start ] = csr .indices [start :end ]
779789 coeffs_row [: end - start ] = csr .data [start :end ]
780- return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [self ._sign ]} { self ._rhs [row ]} "
790+ sign = self ._sign if isinstance (self ._sign , str ) else self ._sign [row ]
791+ return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [sign ]} { self ._rhs [row ]} "
781792
782793 if size > 1 :
783794 for indices in generate_indices_for_printout (shape , max_lines ):
@@ -819,21 +830,22 @@ def to_netcdf_ds(self) -> Dataset:
819830 "rhs" : DataArray (self ._rhs , dims = ["_flat" ]),
820831 "_con_labels" : DataArray (self ._con_labels , dims = ["_flat" ]),
821832 }
833+ if isinstance (self ._sign , np .ndarray ):
834+ data_vars ["_sign" ] = DataArray (self ._sign , dims = ["_flat" ])
822835 data_vars .update (coords_to_dataset_vars (self ._coords ))
823836 if self ._dual is not None :
824837 data_vars ["dual" ] = DataArray (self ._dual , dims = ["_flat" ])
825838 dim_names = [c .name for c in self ._coords ]
826- return Dataset (
827- data_vars ,
828- attrs = {
829- "_linopy_format" : "csr" ,
830- "sign" : self ._sign ,
831- "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
832- "shape" : list (csr .shape ),
833- "coord_dims" : dim_names ,
834- "name" : self ._name ,
835- },
836- )
839+ attrs : dict [str , Any ] = {
840+ "_linopy_format" : "csr" ,
841+ "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
842+ "shape" : list (csr .shape ),
843+ "coord_dims" : dim_names ,
844+ "name" : self ._name ,
845+ }
846+ if isinstance (self ._sign , str ):
847+ attrs ["sign" ] = self ._sign
848+ return Dataset (data_vars , attrs = attrs )
837849
838850 @classmethod
839851 def from_netcdf_ds (cls , ds : Dataset , model : Model , name : str ) -> Constraint :
@@ -845,7 +857,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
845857 shape = shape ,
846858 )
847859 rhs = ds ["rhs" ].values
848- sign = attrs ["sign" ]
860+ sign : str | np . ndarray = ds [ "_sign" ]. values if "_sign" in ds else attrs ["sign" ]
849861 _cindex_raw = int (attrs ["cindex" ])
850862 cindex : int | None = _cindex_raw if _cindex_raw >= 0 else None
851863 coord_dims = attrs ["coord_dims" ]
@@ -873,7 +885,10 @@ def to_matrix_with_rhs(
873885 self , label_index : VariableLabelIndex
874886 ) -> tuple [scipy .sparse .csr_array , np .ndarray , np .ndarray , np .ndarray ]:
875887 """Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
876- sense = np .full (len (self ._rhs ), self ._sign [0 ])
888+ if isinstance (self ._sign , str ):
889+ sense = np .full (len (self ._rhs ), self ._sign [0 ])
890+ else :
891+ sense = np .array ([s [0 ] for s in self ._sign ])
877892 return self ._csr , self ._con_labels , self ._rhs , sense
878893
879894 def sanitize_zeros (self ) -> Constraint :
@@ -888,18 +903,25 @@ def sanitize_missings(self) -> Constraint:
888903
889904 def sanitize_infinities (self ) -> Constraint :
890905 """Mask out rows with invalid infinite RHS values (mutates in-place)."""
891- if self ._sign == LESS_EQUAL :
892- invalid = self ._rhs == np .inf
893- elif self ._sign == GREATER_EQUAL :
894- invalid = self ._rhs == - np .inf
906+ if isinstance (self ._sign , str ):
907+ if self ._sign == LESS_EQUAL :
908+ invalid = self ._rhs == np .inf
909+ elif self ._sign == GREATER_EQUAL :
910+ invalid = self ._rhs == - np .inf
911+ else :
912+ return self
895913 else :
896- return self
914+ invalid = ((self ._sign == LESS_EQUAL ) & (self ._rhs == np .inf )) | (
915+ (self ._sign == GREATER_EQUAL ) & (self ._rhs == - np .inf )
916+ )
897917 if not invalid .any ():
898918 return self
899919 keep = ~ invalid
900920 self ._csr = self ._csr [keep ]
901921 self ._con_labels = self ._con_labels [keep ]
902922 self ._rhs = self ._rhs [keep ]
923+ if not isinstance (self ._sign , str ):
924+ self ._sign = self ._sign [keep ]
903925 return self
904926
905927 def freeze (self ) -> Constraint :
@@ -939,13 +961,14 @@ def from_mutable(
939961 active_mask = (labels_flat != - 1 ) & (vars_flat != - 1 ).any (axis = 1 )
940962 rhs = con .rhs .values .ravel ()[active_mask ]
941963 sign_vals = con .sign .values .ravel ()
942- unique_signs = np .unique (sign_vals [active_mask ])
943- if len (unique_signs ) > 1 :
944- raise ValueError (
945- "Constraint has per-element signs; cannot freeze to immutable Constraint. "
946- "This is a known limitation — use MutableConstraint instead."
947- )
948- sign = str (unique_signs [0 ]) if len (unique_signs ) == 1 else "="
964+ active_signs = sign_vals [active_mask ]
965+ unique_signs = np .unique (active_signs )
966+ if len (unique_signs ) == 0 :
967+ sign : str | np .ndarray = "="
968+ elif len (unique_signs ) == 1 :
969+ sign = str (unique_signs [0 ])
970+ else :
971+ sign = active_signs
949972 dual = (
950973 con .data ["dual" ].values .ravel ()[active_mask ] if "dual" in con .data else None
951974 )
0 commit comments