Skip to content

Commit 2021699

Browse files
committed
Use faster numpy operation
1 parent b8f1101 commit 2021699

1 file changed

Lines changed: 21 additions & 9 deletions

File tree

linopy/model.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,10 @@ def add_variables(
552552
if mask is not None:
553553
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
554554

555-
# Auto-mask based on NaN in bounds
555+
# Auto-mask based on NaN in bounds (use numpy for speed)
556556
if self.auto_mask:
557-
auto_mask_arr = data.lower.notnull() & data.upper.notnull()
557+
auto_mask_values = ~np.isnan(data.lower.values) & ~np.isnan(data.upper.values)
558+
auto_mask_arr = DataArray(auto_mask_values, coords=data.coords, dims=data.dims)
558559
if mask is not None:
559560
mask = mask & auto_mask_arr
560561
else:
@@ -686,9 +687,11 @@ def add_constraints(
686687

687688
# Capture original RHS for auto-masking before constraint creation
688689
# (NaN values in RHS are lost during constraint creation)
689-
original_rhs_notnull = None
690+
# Use numpy for speed instead of xarray's notnull()
691+
original_rhs_mask = None
690692
if self.auto_mask and rhs is not None:
691-
original_rhs_notnull = as_dataarray(rhs).notnull()
693+
rhs_da = as_dataarray(rhs)
694+
original_rhs_mask = (rhs_da.coords, rhs_da.dims, ~np.isnan(rhs_da.values))
692695

693696
if isinstance(lhs, LinearExpression):
694697
if sign is None or rhs is None:
@@ -743,12 +746,21 @@ def add_constraints(
743746
"Dimensions of mask not a subset of resulting labels dimensions."
744747
)
745748

746-
# Auto-mask based on null expressions or NaN RHS
749+
# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
747750
if self.auto_mask:
748-
expr = LinearExpression(data, self)
749-
auto_mask_arr = ~expr.isnull()
750-
if original_rhs_notnull is not None:
751-
auto_mask_arr = auto_mask_arr & original_rhs_notnull
751+
# Check if expression is null: all vars == -1
752+
# This is equivalent to LinearExpression(data, self).isnull() but faster
753+
vars_all_invalid = (data.vars.values == -1).all(axis=-1) # Along TERM_DIM
754+
auto_mask_values = ~vars_all_invalid
755+
if original_rhs_mask is not None:
756+
coords, dims, rhs_notnull = original_rhs_mask
757+
# Broadcast RHS mask to match data shape if needed
758+
if rhs_notnull.shape != auto_mask_values.shape:
759+
rhs_da = DataArray(rhs_notnull, coords=coords, dims=dims)
760+
rhs_da, _ = xr.broadcast(rhs_da, data.labels)
761+
rhs_notnull = rhs_da.values
762+
auto_mask_values = auto_mask_values & rhs_notnull
763+
auto_mask_arr = DataArray(auto_mask_values, coords=data.labels.coords, dims=data.labels.dims)
752764
if mask is not None:
753765
mask = mask & auto_mask_arr
754766
else:

0 commit comments

Comments
 (0)