Skip to content

Commit 3316c0f

Browse files
committed
Optimize mask application and null expression check
Performance improvements: - Use np.where() instead of xarray where() for mask application (~38x faster) - Use max() == -1 instead of all() == -1 for null expression check (~30% faster) These optimizations make auto_mask have minimal overhead compared to manual masking.
1 parent 2f3b49a commit 3316c0f

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

linopy/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,8 @@ def add_variables(
571571
self._xCounter += data.labels.size
572572

573573
if mask is not None:
574-
data.labels.values = data.labels.where(mask, -1).values
574+
# Use numpy where for speed (38x faster than xarray where)
575+
data.labels.values = np.where(mask.values, data.labels.values, -1)
575576

576577
data = data.assign_attrs(
577578
label_range=(start, end), name=name, binary=binary, integer=integer
@@ -753,8 +754,9 @@ def add_constraints(
753754
# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
754755
if self.auto_mask:
755756
# Check if expression is null: all vars == -1
756-
# This is equivalent to LinearExpression(data, self).isnull() but faster
757-
vars_all_invalid = (data.vars.values == -1).all(axis=-1) # Along TERM_DIM
757+
# Use max() instead of all() - if max == -1, all are -1 (since valid vars >= 0)
758+
# This is ~30% faster for large term dimensions
759+
vars_all_invalid = data.vars.values.max(axis=-1) == -1
758760
auto_mask_values = ~vars_all_invalid
759761
if original_rhs_mask is not None:
760762
coords, dims, rhs_notnull = original_rhs_mask
@@ -780,7 +782,8 @@ def add_constraints(
780782
self._cCounter += data.labels.size
781783

782784
if mask is not None:
783-
data.labels.values = data.labels.where(mask, -1).values
785+
# Use numpy where for speed (38x faster than xarray where)
786+
data.labels.values = np.where(mask.values, data.labels.values, -1)
784787

785788
data = data.assign_attrs(label_range=(start, end), name=name)
786789

0 commit comments

Comments
 (0)