Skip to content

Commit 1c57074

Browse files
authored
Fix UxDataset constructor and to_xarray() that got broken with xarray==2026.4.0 (#1492)
* Fix UxDataset constructor and to_xarray() * Remove forgotten line * Address @rajeeja's pointed case * Add test case
1 parent edded0c commit 1c57074

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

test/core/test_dataset.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,16 @@ def test_sel_method_forwarded(gridpath, datasetpath):
9999
nearest["time"].values,
100100
np.array(uxds["time"].values[2], dtype="datetime64[ns]"),
101101
)
102+
103+
def test_uxdataset_init_from_xarray_dataset():
104+
ds = xr.Dataset(
105+
data_vars={"a": ("x", [1, 2])},
106+
coords={"x": [10, 20]},
107+
attrs={"source": "testing"},
108+
)
109+
110+
uxds = ux.UxDataset(ds)
111+
112+
assert "a" in uxds.data_vars
113+
assert "x" in uxds.coords
114+
assert uxds.attrs["source"] == "testing"

uxarray/core/dataset.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,24 @@ def __init__(
9090
else:
9191
self._uxgrid = uxgrid
9292

93+
# As of xarray's 2026.4.0, `xr.Dataset(xr.Dataset)` is prohibited;
94+
# hence this check, i.e. if we get `xr.Dataset` as input, use its `data_vars`
95+
# as `dict` and handle `coords` and `attrs` properly as well
96+
if args and isinstance(args[0], xr.Dataset):
97+
ds = args[0]
98+
# Replacee only args[0], `ds`, with `ds.data_vars` as `dict`
99+
args = (dict(ds.data_vars),) + args[1:]
100+
# coords not passed positionally
101+
if len(args) < 2:
102+
kwargs.setdefault(
103+
"coords", dict(ds.coords)
104+
) # Set it as kwarg only if not explicitly provided
105+
# attrs not passed positionally
106+
if len(args) < 3:
107+
kwargs.setdefault(
108+
"attrs", ds.attrs
109+
) # Set it as kwarg only if not explicitly provided
110+
93111
super().__init__(*args, **kwargs)
94112

95113
# declare plotting accessor
@@ -627,9 +645,9 @@ def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
627645
"""
628646
if grid_format == "HEALPix":
629647
ds = self.rename_dims({"n_face": "cell"})
630-
return xr.Dataset(ds)
648+
return xr.Dataset(ds.data_vars, coords=ds.coords, attrs=ds.attrs)
631649

632-
return xr.Dataset(self)
650+
return xr.Dataset(self.data_vars, coords=self.coords, attrs=self.attrs)
633651

634652
def get_dual(self):
635653
"""Compute the dual mesh for a dataset, returns a new dataset object.

0 commit comments

Comments
 (0)