@@ -511,7 +511,7 @@ def __init__(
511511 csr : scipy .sparse .csr_array ,
512512 con_labels : np .ndarray ,
513513 rhs : np .ndarray ,
514- sign : str ,
514+ sign : str | np . ndarray ,
515515 coords : list [pd .Index ],
516516 model : Model ,
517517 name : str = "" ,
@@ -596,7 +596,7 @@ def coord_names(self) -> list[str]:
596596 return [c .name for c in self ._coords ]
597597
598598 def _active_to_dataarray (
599- self , active_values : np .ndarray , fill : float | int = - 1
599+ self , active_values : np .ndarray , fill : float | int | str = - 1
600600 ) -> DataArray :
601601 full = np .full (self .full_size , fill , dtype = active_values .dtype )
602602 full [self .active_positions ] = active_values
@@ -633,14 +633,40 @@ def vars(self) -> DataArray:
633633
634634 @property
635635 def sign (self ) -> DataArray :
636- """Get sign DataArray (scalar, same sign for all entries)."""
637- return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
636+ """Get sign DataArray."""
637+ if isinstance (self ._sign , str ):
638+ return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
639+ return self ._active_to_dataarray (self ._sign , fill = "" )
638640
639641 @property
640642 def rhs (self ) -> DataArray :
641643 """Get RHS DataArray, shape (*coord_dims)."""
642644 return self ._active_to_dataarray (self ._rhs , fill = np .nan )
643645
646+ @rhs .setter
647+ def rhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
648+ self ._refreeze_after (lambda mc : setattr (mc , "rhs" , value ))
649+
650+ @property
651+ def lhs (self ) -> expressions .LinearExpression :
652+ """Get LHS as LinearExpression (triggers Dataset reconstruction)."""
653+ return self .mutable ().lhs
654+
655+ @lhs .setter
656+ def lhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
657+ self ._refreeze_after (lambda mc : setattr (mc , "lhs" , value ))
658+
659+ def _refreeze_after (self , mutate : Callable [[MutableConstraint ], None ]) -> None :
660+ mc = self .mutable ()
661+ mutate (mc )
662+ refrozen = Constraint .from_mutable (mc , self ._cindex )
663+ self ._csr = refrozen ._csr
664+ self ._con_labels = refrozen ._con_labels
665+ self ._rhs = refrozen ._rhs
666+ self ._sign = refrozen ._sign
667+ self ._coords = refrozen ._coords
668+ self ._dual = None
669+
644670 @property
645671 @has_optimized_model
646672 def dual (self ) -> DataArray :
@@ -745,7 +771,8 @@ def row_expr(row: int) -> str:
745771 coeffs_row = np .zeros (nterm , dtype = csr .dtype )
746772 vars_row [: end - start ] = vlabels [csr .indices [start :end ]]
747773 coeffs_row [: end - start ] = csr .data [start :end ]
748- return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [self ._sign ]} { self ._rhs [row ]} "
774+ sign = self ._sign if isinstance (self ._sign , str ) else self ._sign [row ]
775+ return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [sign ]} { self ._rhs [row ]} "
749776
750777 if size > 1 :
751778 for indices in generate_indices_for_printout (shape , max_lines ):
@@ -787,21 +814,22 @@ def to_netcdf_ds(self) -> Dataset:
787814 "rhs" : DataArray (self ._rhs , dims = ["_flat" ]),
788815 "_con_labels" : DataArray (self ._con_labels , dims = ["_flat" ]),
789816 }
817+ if isinstance (self ._sign , np .ndarray ):
818+ data_vars ["_sign" ] = DataArray (self ._sign , dims = ["_flat" ])
790819 data_vars .update (coords_to_dataset_vars (self ._coords ))
791820 if self ._dual is not None :
792821 data_vars ["dual" ] = DataArray (self ._dual , dims = ["_flat" ])
793822 dim_names = [c .name for c in self ._coords ]
794- return Dataset (
795- data_vars ,
796- attrs = {
797- "_linopy_format" : "csr" ,
798- "sign" : self ._sign ,
799- "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
800- "shape" : list (csr .shape ),
801- "coord_dims" : dim_names ,
802- "name" : self ._name ,
803- },
804- )
823+ attrs : dict [str , Any ] = {
824+ "_linopy_format" : "csr" ,
825+ "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
826+ "shape" : list (csr .shape ),
827+ "coord_dims" : dim_names ,
828+ "name" : self ._name ,
829+ }
830+ if isinstance (self ._sign , str ):
831+ attrs ["sign" ] = self ._sign
832+ return Dataset (data_vars , attrs = attrs )
805833
806834 @classmethod
807835 def from_netcdf_ds (cls , ds : Dataset , model : Model , name : str ) -> Constraint :
@@ -813,7 +841,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
813841 shape = shape ,
814842 )
815843 rhs = ds ["rhs" ].values
816- sign = attrs ["sign" ]
844+ sign : str | np . ndarray = ds [ "_sign" ]. values if "_sign" in ds else attrs ["sign" ]
817845 cindex = int (attrs ["cindex" ])
818846 cindex = cindex if cindex >= 0 else None
819847 coord_dims = attrs ["coord_dims" ]
@@ -841,7 +869,10 @@ def to_matrix_with_rhs(
841869 self , label_index : VariableLabelIndex
842870 ) -> tuple [scipy .sparse .csr_array , np .ndarray , np .ndarray , np .ndarray ]:
843871 """Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
844- sense = np .full (len (self ._rhs ), self ._sign [0 ])
872+ if isinstance (self ._sign , str ):
873+ sense = np .full (len (self ._rhs ), self ._sign [0 ])
874+ else :
875+ sense = np .array ([s [0 ] for s in self ._sign ])
845876 return self ._csr , self ._con_labels , self ._rhs , sense
846877
847878 def sanitize_zeros (self ) -> Constraint :
@@ -856,18 +887,25 @@ def sanitize_missings(self) -> Constraint:
856887
857888 def sanitize_infinities (self ) -> Constraint :
858889 """Mask out rows with invalid infinite RHS values (mutates in-place)."""
859- if self ._sign == LESS_EQUAL :
860- invalid = self ._rhs == np .inf
861- elif self ._sign == GREATER_EQUAL :
862- invalid = self ._rhs == - np .inf
890+ if isinstance (self ._sign , str ):
891+ if self ._sign == LESS_EQUAL :
892+ invalid = self ._rhs == np .inf
893+ elif self ._sign == GREATER_EQUAL :
894+ invalid = self ._rhs == - np .inf
895+ else :
896+ return self
863897 else :
864- return self
898+ invalid = ((self ._sign == LESS_EQUAL ) & (self ._rhs == np .inf )) | (
899+ (self ._sign == GREATER_EQUAL ) & (self ._rhs == - np .inf )
900+ )
865901 if not invalid .any ():
866902 return self
867903 keep = ~ invalid
868904 self ._csr = self ._csr [keep ]
869905 self ._con_labels = self ._con_labels [keep ]
870906 self ._rhs = self ._rhs [keep ]
907+ if not isinstance (self ._sign , str ):
908+ self ._sign = self ._sign [keep ]
871909 return self
872910
873911 def freeze (self ) -> Constraint :
@@ -907,13 +945,14 @@ def from_mutable(
907945 active_mask = (labels_flat != - 1 ) & (vars_flat != - 1 ).any (axis = 1 )
908946 rhs = con .rhs .values .ravel ()[active_mask ]
909947 sign_vals = con .sign .values .ravel ()
910- unique_signs = np .unique (sign_vals [active_mask ])
911- if len (unique_signs ) > 1 :
912- raise ValueError (
913- "Constraint has per-element signs; cannot freeze to immutable Constraint. "
914- "This is a known limitation — use MutableConstraint instead."
915- )
916- sign = str (unique_signs [0 ]) if len (unique_signs ) == 1 else "="
948+ active_signs = sign_vals [active_mask ]
949+ unique_signs = np .unique (active_signs )
950+ if len (unique_signs ) == 0 :
951+ sign : str | np .ndarray = "="
952+ elif len (unique_signs ) == 1 :
953+ sign = str (unique_signs [0 ])
954+ else :
955+ sign = active_signs
917956 dual = (
918957 con .data ["dual" ].values .ravel ()[active_mask ] if "dual" in con .data else None
919958 )
0 commit comments