Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ def test_sel_method_forwarded(gridpath, datasetpath):
nearest["time"].values,
np.array(uxds["time"].values[2], dtype="datetime64[ns]"),
)

def test_uxdataset_init_from_xarray_dataset():
ds = xr.Dataset(
data_vars={"a": ("x", [1, 2])},
coords={"x": [10, 20]},
attrs={"source": "testing"},
)

uxds = ux.UxDataset(ds)

assert "a" in uxds.data_vars
assert "x" in uxds.coords
assert uxds.attrs["source"] == "testing"
22 changes: 20 additions & 2 deletions uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ def __init__(
else:
self._uxgrid = uxgrid

# As of xarray's 2026.4.0, `xr.Dataset(xr.Dataset)` is prohibited;
# hence this check, i.e. if we get `xr.Dataset` as input, use its `data_vars`
# as `dict` and handle `coords` and `attrs` properly as well
if args and isinstance(args[0], xr.Dataset):
ds = args[0]
Comment thread
erogluorhan marked this conversation as resolved.
# Replacee only args[0], `ds`, with `ds.data_vars` as `dict`
args = (dict(ds.data_vars),) + args[1:]
# coords not passed positionally
if len(args) < 2:
kwargs.setdefault(
"coords", dict(ds.coords)
) # Set it as kwarg only if not explicitly provided
# attrs not passed positionally
if len(args) < 3:
kwargs.setdefault(
"attrs", ds.attrs
) # Set it as kwarg only if not explicitly provided

super().__init__(*args, **kwargs)

# declare plotting accessor
Expand Down Expand Up @@ -627,9 +645,9 @@ def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
"""
if grid_format == "HEALPix":
ds = self.rename_dims({"n_face": "cell"})
return xr.Dataset(ds)
return xr.Dataset(ds.data_vars, coords=ds.coords, attrs=ds.attrs)

return xr.Dataset(self)
return xr.Dataset(self.data_vars, coords=self.coords, attrs=self.attrs)

def get_dual(self):
"""Compute the dual mesh for a dataset, returns a new dataset object.
Expand Down
Loading