Skip to content

Commit 3f5d866

Browse files
FBumannclaude
andcommitted
Fix DataArray validation for upstream compatibility
- Skip coord validation for DataArray inputs in arithmetic contexts (allow_extra_dims=True) to preserve xarray's native alignment - Add allow_extra_dims=True to comparison operator and quadratic dot as_dataarray calls for consistent broadcasting - Handle MultiIndex levels in expand_dims guard Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent be93663 commit 3f5d866

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

linopy/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def _validate_dataarray_coords(
172172
expand = {}
173173
for k, v in expected.items():
174174
if k not in arr.dims:
175+
# Skip coords that already exist (e.g. MultiIndex levels)
176+
if k in arr.coords:
177+
continue
175178
expand[k] = v
176179
continue
177180
actual = arr.coords[k]
@@ -240,7 +243,6 @@ def pandas_to_dataarray(
240243
" for alignment."
241244
)
242245

243-
244246
return DataArray(arr, coords=None, dims=dims, **kwargs)
245247

246248

@@ -341,7 +343,7 @@ def as_dataarray(
341343
elif isinstance(arr, int | float | str | bool | list):
342344
arr = DataArray(arr, coords=coords, dims=dims, **kwargs)
343345
elif isinstance(arr, DataArray):
344-
if coords is not None:
346+
if coords is not None and not allow_extra_dims:
345347
arr = _validate_dataarray_coords(arr, coords, dims, allow_extra_dims)
346348
elif not isinstance(arr, DataArray):
347349
supported_types = [

linopy/expressions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,9 @@ def to_constraint(
11461146
)
11471147

11481148
if isinstance(rhs, SUPPORTED_CONSTANT_TYPES):
1149-
rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims)
1149+
rhs = as_dataarray(
1150+
rhs, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
1151+
)
11501152

11511153
extra_dims = set(rhs.dims) - set(self.coord_dims)
11521154
if extra_dims:
@@ -2197,7 +2199,9 @@ def __matmul__(
21972199
"Higher order non-linear expressions are not yet supported."
21982200
)
21992201

2200-
other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
2202+
other = as_dataarray(
2203+
other, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
2204+
)
22012205
common_dims = list(set(self.coord_dims).intersection(other.dims))
22022206
return (self * other).sum(dim=common_dims)
22032207

0 commit comments

Comments
 (0)