Skip to content

Commit 75c442c

Browse files
Reinsert broadcasted mask (#580)
* reinsert broadcasting of masks * update release notes * consolidate broadcast mask into new function, add tests for subsets * align test logic to broadcasting * Reinsert broadcasted mask (#581) * 1. Moved the dimension subset check into broadcast_mask 2. Added a brief docstring to broadcast_mask * Add tests for superset dims --------- Co-authored-by: FBumann <117816358+FBumann@users.noreply.github.com>
1 parent ec6262b commit 75c442c

File tree

6 files changed

+71
-82
lines changed

6 files changed

+71
-82
lines changed

doc/release_notes.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
**Fix Regression**
8+
9+
* Reinsert broadcasting logic of mask object to be fully compatible with performance improvements in version 0.6.2 using `np.where` instead of `xr.where`.
10+
11+
712
Version 0.6.2
813
--------------
914

linopy/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,32 @@ def as_dataarray(
286286
return arr
287287

288288

289+
def broadcast_mask(mask: DataArray, labels: DataArray) -> DataArray:
290+
"""
291+
Broadcast a boolean mask to match the shape of labels.
292+
293+
Ensures that mask dimensions are a subset of labels dimensions, broadcasts
294+
the mask accordingly, and fills any NaN values (from missing coordinates)
295+
with False while emitting a FutureWarning.
296+
"""
297+
assert set(mask.dims).issubset(labels.dims), (
298+
"Dimensions of mask not a subset of resulting labels dimensions."
299+
)
300+
mask = mask.broadcast_like(labels)
301+
if mask.isnull().any():
302+
warn(
303+
"Mask contains coordinates not covered by the data dimensions. "
304+
"Missing values will be filled with False (masked out). "
305+
"In a future version, this will raise an error. "
306+
"Use mask.reindex() or `linopy.align()` to explicitly handle missing "
307+
"coordinates.",
308+
FutureWarning,
309+
stacklevel=3,
310+
)
311+
mask = mask.fillna(False).astype(bool)
312+
return mask
313+
314+
289315
# TODO: rename to to_pandas_dataframe
290316
def to_dataframe(
291317
ds: Dataset,

linopy/model.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import logging
1010
import os
1111
import re
12-
import warnings
1312
from collections.abc import Callable, Mapping, Sequence
1413
from pathlib import Path
1514
from tempfile import NamedTemporaryFile, gettempdir
@@ -30,6 +29,7 @@
3029
as_dataarray,
3130
assign_multiindex_safe,
3231
best_int,
32+
broadcast_mask,
3333
maybe_replace_signs,
3434
replace_by_map,
3535
set_int_index,
@@ -552,16 +552,7 @@ def add_variables(
552552

553553
if mask is not None:
554554
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
555-
if set(mask.dims) != set(data["labels"].dims):
556-
warnings.warn(
557-
f"Mask dimensions {set(mask.dims)} do not match the data "
558-
f"dimensions {set(data['labels'].dims)}. The mask will be "
559-
f"broadcast across the missing dimensions "
560-
f"{set(data['labels'].dims) - set(mask.dims)}. In a future "
561-
"version, this will raise an error.",
562-
FutureWarning,
563-
stacklevel=2,
564-
)
555+
mask = broadcast_mask(mask, data.labels)
565556

566557
# Auto-mask based on NaN in bounds (use numpy for speed)
567558
if self.auto_mask:
@@ -582,7 +573,7 @@ def add_variables(
582573
self._xCounter += data.labels.size
583574

584575
if mask is not None:
585-
data.labels.values = data.labels.where(mask, -1).values
576+
data.labels.values = np.where(mask.values, data.labels.values, -1)
586577

587578
data = data.assign_attrs(
588579
label_range=(start, end), name=name, binary=binary, integer=integer
@@ -756,20 +747,7 @@ def add_constraints(
756747

757748
if mask is not None:
758749
mask = as_dataarray(mask).astype(bool)
759-
# TODO: simplify
760-
assert set(mask.dims).issubset(data.dims), (
761-
"Dimensions of mask not a subset of resulting labels dimensions."
762-
)
763-
if set(mask.dims) != set(data["labels"].dims):
764-
warnings.warn(
765-
f"Mask dimensions {set(mask.dims)} do not match the data "
766-
f"dimensions {set(data['labels'].dims)}. The mask will be "
767-
f"broadcast across the missing dimensions "
768-
f"{set(data['labels'].dims) - set(mask.dims)}. In a future "
769-
"version, this will raise an error.",
770-
FutureWarning,
771-
stacklevel=2,
772-
)
750+
mask = broadcast_mask(mask, data.labels)
773751

774752
# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
775753
if self.auto_mask:
@@ -780,11 +758,9 @@ def add_constraints(
780758
auto_mask_values = ~vars_all_invalid
781759
if original_rhs_mask is not None:
782760
coords, dims, rhs_notnull = original_rhs_mask
783-
# Broadcast RHS mask to match data shape if needed
784761
if rhs_notnull.shape != auto_mask_values.shape:
785762
rhs_da = DataArray(rhs_notnull, coords=coords, dims=dims)
786-
rhs_da, _ = xr.broadcast(rhs_da, data.labels)
787-
rhs_notnull = rhs_da.values
763+
rhs_notnull = rhs_da.broadcast_like(data.labels).values
788764
auto_mask_values = auto_mask_values & rhs_notnull
789765
auto_mask_arr = DataArray(
790766
auto_mask_values, coords=data.labels.coords, dims=data.labels.dims
@@ -802,7 +778,7 @@ def add_constraints(
802778
self._cCounter += data.labels.size
803779

804780
if mask is not None:
805-
data.labels.values = data.labels.where(mask, -1).values
781+
data.labels.values = np.where(mask.values, data.labels.values, -1)
806782

807783
data = data.assign_attrs(label_range=(start, end), name=name)
808784

test/test_constraints.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -157,54 +157,44 @@ def test_masked_constraints() -> None:
157157
y = m.add_variables()
158158

159159
mask = pd.Series([True] * 5 + [False] * 5)
160-
with pytest.warns(FutureWarning, match="Mask dimensions"):
161-
m.add_constraints(1 * x + 10 * y, EQUAL, 0, mask=mask)
160+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, mask=mask)
162161
assert (m.constraints.labels.con0[0:5, :] != -1).all()
163162
assert (m.constraints.labels.con0[5:10, :] == -1).all()
164163

165164

166165
def test_masked_constraints_broadcast() -> None:
167-
"""Test that a constraint mask with fewer dimensions broadcasts correctly."""
168166
m: Model = Model()
169167

170168
lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)])
171169
upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
172170
x = m.add_variables(lower, upper)
173171
y = m.add_variables()
174172

175-
# 1D mask applied to 2D constraint — must broadcast over second dim
176173
mask = pd.Series([True] * 5 + [False] * 5)
177-
with pytest.warns(FutureWarning, match="Mask dimensions"):
178-
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc1", mask=mask)
174+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc1", mask=mask)
179175
assert (m.constraints.labels.bc1[0:5, :] != -1).all()
180176
assert (m.constraints.labels.bc1[5:10, :] == -1).all()
181177

182-
# Mask along second dimension only
183178
mask2 = xr.DataArray([True] * 5 + [False] * 5, dims=["dim_1"])
184-
with pytest.warns(FutureWarning, match="Mask dimensions"):
185-
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc2", mask=mask2)
179+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc2", mask=mask2)
186180
assert (m.constraints.labels.bc2[:, 0:5] != -1).all()
187181
assert (m.constraints.labels.bc2[:, 5:10] == -1).all()
188182

189-
190-
def test_constraints_mask_no_warning_when_aligned() -> None:
191-
"""Test that no FutureWarning is emitted when mask has same dims as data."""
192-
m: Model = Model()
193-
194-
lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)])
195-
upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
196-
x = m.add_variables(lower, upper)
197-
y = m.add_variables()
198-
199-
mask = xr.DataArray(
200-
np.array([[True] * 10] * 5 + [[False] * 10] * 5),
201-
coords=[range(10), range(10)],
183+
mask3 = xr.DataArray(
184+
[True, True, False, False, False],
185+
dims=["dim_0"],
186+
coords={"dim_0": range(5)},
202187
)
203-
import warnings
204-
205-
with warnings.catch_warnings():
206-
warnings.simplefilter("error", FutureWarning)
207-
m.add_constraints(1 * x + 10 * y, EQUAL, 0, mask=mask)
188+
with pytest.warns(FutureWarning, match="Missing values will be filled"):
189+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc3", mask=mask3)
190+
assert (m.constraints.labels.bc3[0:2, :] != -1).all()
191+
assert (m.constraints.labels.bc3[2:5, :] == -1).all()
192+
assert (m.constraints.labels.bc3[5:10, :] == -1).all()
193+
194+
# Mask with extra dimension not in data should raise
195+
mask4 = xr.DataArray([True, False], dims=["extra_dim"])
196+
with pytest.raises(AssertionError, match="not a subset"):
197+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc4", mask=mask4)
208198

209199

210200
def test_non_aligned_constraints() -> None:

test/test_variable_assignment.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ def test_variable_assigment_masked() -> None:
227227
lower = pd.DataFrame(np.zeros((10, 10)))
228228
upper = pd.Series(np.ones(10))
229229
mask = pd.Series([True] * 5 + [False] * 5)
230-
with pytest.warns(FutureWarning, match="Mask dimensions"):
231-
m.add_variables(lower, upper, mask=mask)
230+
m.add_variables(lower, upper, mask=mask)
232231
assert m.variables.labels.var0[-1, -1].item() == -1
233232

234233

test/test_variables.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,43 +108,36 @@ def test_variables_nvars(m: Model) -> None:
108108

109109

110110
def test_variables_mask_broadcast() -> None:
111-
"""Test that a mask with fewer dimensions broadcasts correctly."""
112111
m = Model()
113112

114113
lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)])
115114
upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
116115

117-
# 1D mask applied to 2D variable — must broadcast over second dim
118116
mask = pd.Series([True] * 5 + [False] * 5)
119-
with pytest.warns(FutureWarning, match="Mask dimensions"):
120-
x = m.add_variables(lower, upper, name="x", mask=mask)
117+
x = m.add_variables(lower, upper, name="x", mask=mask)
121118
assert (x.labels[0:5, :] != -1).all()
122119
assert (x.labels[5:10, :] == -1).all()
123120

124-
# Mask along second dimension only
125121
mask2 = xr.DataArray([True] * 5 + [False] * 5, dims=["dim_1"])
126-
with pytest.warns(FutureWarning, match="Mask dimensions"):
127-
y = m.add_variables(lower, upper, name="y", mask=mask2)
122+
y = m.add_variables(lower, upper, name="y", mask=mask2)
128123
assert (y.labels[:, 0:5] != -1).all()
129124
assert (y.labels[:, 5:10] == -1).all()
130125

131-
132-
def test_variables_mask_no_warning_when_aligned() -> None:
133-
"""Test that no FutureWarning is emitted when mask has same dims as data."""
134-
m = Model()
135-
136-
lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)])
137-
upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
138-
139-
mask = xr.DataArray(
140-
np.array([[True] * 10] * 5 + [[False] * 10] * 5),
141-
coords=[range(10), range(10)],
126+
mask3 = xr.DataArray(
127+
[True, True, False, False, False],
128+
dims=["dim_0"],
129+
coords={"dim_0": range(5)},
142130
)
143-
import warnings
144-
145-
with warnings.catch_warnings():
146-
warnings.simplefilter("error", FutureWarning)
147-
m.add_variables(lower, upper, name="x", mask=mask)
131+
with pytest.warns(FutureWarning, match="Missing values will be filled"):
132+
z = m.add_variables(lower, upper, name="z", mask=mask3)
133+
assert (z.labels[0:2, :] != -1).all()
134+
assert (z.labels[2:5, :] == -1).all()
135+
assert (z.labels[5:10, :] == -1).all()
136+
137+
# Mask with extra dimension not in data should raise
138+
mask4 = xr.DataArray([True, False], dims=["extra_dim"])
139+
with pytest.raises(AssertionError, match="not a subset"):
140+
m.add_variables(lower, upper, name="w", mask=mask4)
148141

149142

150143
def test_variables_get_name_by_label(m: Model) -> None:

0 commit comments

Comments
 (0)