@@ -529,7 +529,7 @@ def __init__(
529529 csr : scipy .sparse .csr_array ,
530530 con_labels : np .ndarray ,
531531 rhs : np .ndarray ,
532- sign : str ,
532+ sign : str | np . ndarray ,
533533 coords : list [pd .Index ],
534534 model : Model ,
535535 name : str = "" ,
@@ -614,7 +614,7 @@ def coord_names(self) -> list[str]:
614614 return [str (c .name ) for c in self ._coords ]
615615
616616 def _active_to_dataarray (
617- self , active_values : np .ndarray , fill : float | int = - 1
617+ self , active_values : np .ndarray , fill : float | int | str = - 1
618618 ) -> DataArray :
619619 full = np .full (self .full_size , fill , dtype = active_values .dtype )
620620 full [self .active_positions ] = active_values
@@ -651,14 +651,40 @@ def vars(self) -> DataArray:
651651
652652 @property
653653 def sign (self ) -> DataArray :
654- """Get sign DataArray (scalar, same sign for all entries)."""
655- return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
654+ """Get sign DataArray."""
655+ if isinstance (self ._sign , str ):
656+ return DataArray (np .full (self .shape , self ._sign ), coords = self ._coords )
657+ return self ._active_to_dataarray (self ._sign , fill = "" )
656658
657659 @property
658660 def rhs (self ) -> DataArray :
659661 """Get RHS DataArray, shape (*coord_dims)."""
660662 return self ._active_to_dataarray (self ._rhs , fill = np .nan )
661663
664+ @rhs .setter
665+ def rhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
666+ self ._refreeze_after (lambda mc : setattr (mc , "rhs" , value ))
667+
668+ @property
669+ def lhs (self ) -> expressions .LinearExpression :
670+ """Get LHS as LinearExpression (triggers Dataset reconstruction)."""
671+ return self .mutable ().lhs
672+
673+ @lhs .setter
674+ def lhs (self , value : ExpressionLike | VariableLike | ConstantLike ) -> None :
675+ self ._refreeze_after (lambda mc : setattr (mc , "lhs" , value ))
676+
677+ def _refreeze_after (self , mutate : Callable [[MutableConstraint ], None ]) -> None :
678+ mc = self .mutable ()
679+ mutate (mc )
680+ refrozen = Constraint .from_mutable (mc , self ._cindex )
681+ self ._csr = refrozen ._csr
682+ self ._con_labels = refrozen ._con_labels
683+ self ._rhs = refrozen ._rhs
684+ self ._sign = refrozen ._sign
685+ self ._coords = refrozen ._coords
686+ self ._dual = None
687+
662688 @property
663689 @has_optimized_model
664690 def dual (self ) -> DataArray :
@@ -763,7 +789,8 @@ def row_expr(row: int) -> str:
763789 coeffs_row = np .zeros (nterm , dtype = csr .dtype )
764790 vars_row [: end - start ] = vlabels [csr .indices [start :end ]]
765791 coeffs_row [: end - start ] = csr .data [start :end ]
766- return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [self ._sign ]} { self ._rhs [row ]} "
792+ sign = self ._sign if isinstance (self ._sign , str ) else self ._sign [row ]
793+ return f"{ print_single_expression (coeffs_row , vars_row , 0 , self ._model )} { SIGNS_pretty [sign ]} { self ._rhs [row ]} "
767794
768795 if size > 1 :
769796 for indices in generate_indices_for_printout (shape , max_lines ):
@@ -805,21 +832,22 @@ def to_netcdf_ds(self) -> Dataset:
805832 "rhs" : DataArray (self ._rhs , dims = ["_flat" ]),
806833 "_con_labels" : DataArray (self ._con_labels , dims = ["_flat" ]),
807834 }
835+ if isinstance (self ._sign , np .ndarray ):
836+ data_vars ["_sign" ] = DataArray (self ._sign , dims = ["_flat" ])
808837 data_vars .update (coords_to_dataset_vars (self ._coords ))
809838 if self ._dual is not None :
810839 data_vars ["dual" ] = DataArray (self ._dual , dims = ["_flat" ])
811840 dim_names = [c .name for c in self ._coords ]
812- return Dataset (
813- data_vars ,
814- attrs = {
815- "_linopy_format" : "csr" ,
816- "sign" : self ._sign ,
817- "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
818- "shape" : list (csr .shape ),
819- "coord_dims" : dim_names ,
820- "name" : self ._name ,
821- },
822- )
841+ attrs : dict [str , Any ] = {
842+ "_linopy_format" : "csr" ,
843+ "cindex" : self ._cindex if self ._cindex is not None else - 1 ,
844+ "shape" : list (csr .shape ),
845+ "coord_dims" : dim_names ,
846+ "name" : self ._name ,
847+ }
848+ if isinstance (self ._sign , str ):
849+ attrs ["sign" ] = self ._sign
850+ return Dataset (data_vars , attrs = attrs )
823851
824852 @classmethod
825853 def from_netcdf_ds (cls , ds : Dataset , model : Model , name : str ) -> Constraint :
@@ -831,7 +859,9 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
831859 shape = shape ,
832860 )
833861 rhs = ds ["rhs" ].values
834- sign = attrs ["sign" ]
862+ sign : str | np .ndarray = (
863+ ds ["_sign" ].values if "_sign" in ds else attrs ["sign" ]
864+ )
835865 _cindex_raw = int (attrs ["cindex" ])
836866 cindex : int | None = _cindex_raw if _cindex_raw >= 0 else None
837867 coord_dims = attrs ["coord_dims" ]
@@ -859,7 +889,10 @@ def to_matrix_with_rhs(
859889 self , label_index : VariableLabelIndex
860890 ) -> tuple [scipy .sparse .csr_array , np .ndarray , np .ndarray , np .ndarray ]:
861891 """Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
862- sense = np .full (len (self ._rhs ), self ._sign [0 ])
892+ if isinstance (self ._sign , str ):
893+ sense = np .full (len (self ._rhs ), self ._sign [0 ])
894+ else :
895+ sense = np .array ([s [0 ] for s in self ._sign ])
863896 return self ._csr , self ._con_labels , self ._rhs , sense
864897
865898 def sanitize_zeros (self ) -> Constraint :
@@ -874,18 +907,25 @@ def sanitize_missings(self) -> Constraint:
874907
875908 def sanitize_infinities (self ) -> Constraint :
876909 """Mask out rows with invalid infinite RHS values (mutates in-place)."""
877- if self ._sign == LESS_EQUAL :
878- invalid = self ._rhs == np .inf
879- elif self ._sign == GREATER_EQUAL :
880- invalid = self ._rhs == - np .inf
910+ if isinstance (self ._sign , str ):
911+ if self ._sign == LESS_EQUAL :
912+ invalid = self ._rhs == np .inf
913+ elif self ._sign == GREATER_EQUAL :
914+ invalid = self ._rhs == - np .inf
915+ else :
916+ return self
881917 else :
882- return self
918+ invalid = ((self ._sign == LESS_EQUAL ) & (self ._rhs == np .inf )) | (
919+ (self ._sign == GREATER_EQUAL ) & (self ._rhs == - np .inf )
920+ )
883921 if not invalid .any ():
884922 return self
885923 keep = ~ invalid
886924 self ._csr = self ._csr [keep ]
887925 self ._con_labels = self ._con_labels [keep ]
888926 self ._rhs = self ._rhs [keep ]
927+ if not isinstance (self ._sign , str ):
928+ self ._sign = self ._sign [keep ]
889929 return self
890930
891931 def freeze (self ) -> Constraint :
@@ -925,13 +965,14 @@ def from_mutable(
925965 active_mask = (labels_flat != - 1 ) & (vars_flat != - 1 ).any (axis = 1 )
926966 rhs = con .rhs .values .ravel ()[active_mask ]
927967 sign_vals = con .sign .values .ravel ()
928- unique_signs = np .unique (sign_vals [active_mask ])
929- if len (unique_signs ) > 1 :
930- raise ValueError (
931- "Constraint has per-element signs; cannot freeze to immutable Constraint. "
932- "This is a known limitation — use MutableConstraint instead."
933- )
934- sign = str (unique_signs [0 ]) if len (unique_signs ) == 1 else "="
968+ active_signs = sign_vals [active_mask ]
969+ unique_signs = np .unique (active_signs )
970+ if len (unique_signs ) == 0 :
971+ sign : str | np .ndarray = "="
972+ elif len (unique_signs ) == 1 :
973+ sign = str (unique_signs [0 ])
974+ else :
975+ sign = active_signs
935976 dual = (
936977 con .data ["dual" ].values .ravel ()[active_mask ] if "dual" in con .data else None
937978 )
0 commit comments