diff --git a/doc/api/io.rst b/doc/api/io.rst index fffc7785800..62446324e3c 100644 --- a/doc/api/io.rst +++ b/doc/api/io.rst @@ -70,6 +70,7 @@ DataTree methods .. autosummary:: :toctree: ../generated/ + load_datatree open_datatree open_groups DataTree.to_dict diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 38bff514c64..a46dba9f15a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,10 @@ v2025.08.1 (unreleased) New Features ~~~~~~~~~~~~ +- Added :py:func:`load_datatree` for loading ``DataTree`` objects into memory + from disk. It has the same relationship to :py:func:`open_datatree` as + :py:func:`load_dataset` has to :py:func:`open_dataset`. + By `Stephan Hoyer <https://github.com/shoyer>`_. - ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr`. By `Stephan Hoyer <https://github.com/shoyer>`_. diff --git a/xarray/__init__.py b/xarray/__init__.py index 04fb5b03867..7901fffcbed 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -4,6 +4,7 @@ from xarray.backends.api import ( load_dataarray, load_dataset, + load_datatree, open_dataarray, open_dataset, open_datatree, @@ -96,6 +97,7 @@ "infer_freq", "load_dataarray", "load_dataset", + "load_datatree", "map_blocks", "map_over_datasets", "merge", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 86adfc8b7ce..afe840d946f 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -328,7 +328,7 @@ def _multi_file_closer(closers): closer() -def load_dataset(filename_or_obj, **kwargs) -> Dataset: +def load_dataset(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> Dataset: """Open, load into memory, and close a Dataset from a file or file-like object. 
@@ -354,7 +354,7 @@ def load_dataset(filename_or_obj, **kwargs) -> Dataset: return ds.load() -def load_dataarray(filename_or_obj, **kwargs): +def load_dataarray(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> DataArray: """Open, load into memory, and close a DataArray from a file or file-like object containing a single data variable. @@ -380,6 +380,32 @@ def load_dataarray(filename_or_obj, **kwargs): return da.load() +def load_datatree(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> DataTree: + """Open, load into memory, and close a DataTree from a file or file-like + object. + + This is a thin wrapper around :py:func:`~xarray.open_datatree`. It differs + from `open_datatree` in that it loads the DataTree into memory, closes the + file, and returns the DataTree. In contrast, `open_datatree` keeps the file + handle open and lazily loads its contents. All parameters are passed directly + to `open_datatree`. See that documentation for further details. + + Returns + ------- + datatree : DataTree + The newly created DataTree. 
+ + See Also + -------- + open_datatree + """ + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") + + with open_datatree(filename_or_obj, **kwargs) as dt: + return dt.load() + + def _chunk_ds( backend_ds, filename_or_obj, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a8cbb6620d1..3a359cca1ae 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -34,9 +34,11 @@ from xarray import ( DataArray, Dataset, + DataTree, backends, load_dataarray, load_dataset, + load_datatree, open_dataarray, open_dataset, open_mfdataset, @@ -6117,17 +6119,29 @@ def test_load_dataset(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) original.to_netcdf(tmp) ds = load_dataset(tmp) + assert_identical(original, ds) # this would fail if we used open_dataset instead of load_dataset ds.to_netcdf(tmp) def test_load_dataarray(self) -> None: with create_tmp_file() as tmp: - original = Dataset({"foo": ("x", np.random.randn(10))}) + original = DataArray(np.random.randn(10), dims=["x"]) original.to_netcdf(tmp) - ds = load_dataarray(tmp) + da = load_dataarray(tmp) + assert_identical(original, da) # this would fail if we used open_dataarray instead of # load_dataarray - ds.to_netcdf(tmp) + da.to_netcdf(tmp) + + def test_load_datatree(self) -> None: + with create_tmp_file() as tmp: + original = DataTree(Dataset({"foo": ("x", np.random.randn(10))})) + original.to_netcdf(tmp) + dt = load_datatree(tmp) + xr.testing.assert_identical(original, dt) + # this would fail if we used open_datatree instead of + # load_datatree + dt.to_netcdf(tmp) @pytest.mark.skipif( ON_WINDOWS,