Skip to content

Commit 9508117

Browse files
FBumannclaude
andcommitted
Fix DataArray validation for upstream compatibility
- Skip coord validation for DataArray inputs in arithmetic contexts (allow_extra_dims=True) to preserve xarray's native alignment - Add allow_extra_dims=True to comparison operator and quadratic dot as_dataarray calls for consistent broadcasting - Handle MultiIndex levels in expand_dims guard Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent be93663 commit 9508117

2 files changed

Lines changed: 17 additions & 10 deletions

File tree

linopy/common.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,19 @@ def _validate_dataarray_coords(
172172
expand = {}
173173
for k, v in expected.items():
174174
if k not in arr.dims:
175+
# Skip coords that already exist (e.g. MultiIndex levels)
176+
if k in arr.coords:
177+
continue
175178
expand[k] = v
176179
continue
177-
actual = arr.coords[k]
178-
v_idx = v if isinstance(v, pd.Index) else pd.Index(v)
179-
if not actual.to_index().equals(v_idx):
180-
raise ValueError(
181-
f"Coordinates for dimension '{k}' do not match: "
182-
f"expected {v_idx.tolist()}, got {actual.values.tolist()}"
183-
)
180+
if not allow_extra_dims:
181+
actual = arr.coords[k]
182+
v_idx = v if isinstance(v, pd.Index) else pd.Index(v)
183+
if not actual.to_index().equals(v_idx):
184+
raise ValueError(
185+
f"Coordinates for dimension '{k}' do not match: "
186+
f"expected {v_idx.tolist()}, got {actual.values.tolist()}"
187+
)
184188

185189
if expand:
186190
arr = arr.expand_dims(expand)
@@ -240,7 +244,6 @@ def pandas_to_dataarray(
240244
" for alignment."
241245
)
242246

243-
244247
return DataArray(arr, coords=None, dims=dims, **kwargs)
245248

246249

linopy/expressions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,9 @@ def to_constraint(
11461146
)
11471147

11481148
if isinstance(rhs, SUPPORTED_CONSTANT_TYPES):
1149-
rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims)
1149+
rhs = as_dataarray(
1150+
rhs, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
1151+
)
11501152

11511153
extra_dims = set(rhs.dims) - set(self.coord_dims)
11521154
if extra_dims:
@@ -2197,7 +2199,9 @@ def __matmul__(
21972199
"Higher order non-linear expressions are not yet supported."
21982200
)
21992201

2200-
other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
2202+
other = as_dataarray(
2203+
other, coords=self.coords, dims=self.coord_dims, allow_extra_dims=True
2204+
)
22012205
common_dims = list(set(self.coord_dims).intersection(other.dims))
22022206
return (self * other).sum(dim=common_dims)
22032207

0 commit comments

Comments
 (0)