From b3e5b28bf0d01e27a3a368382fa986f36918a728 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 19 Aug 2025 17:44:21 -0700 Subject: [PATCH 1/9] Support DataArray objects in DataTree.from_dict Fixes #9539, #9486 --- xarray/core/datatree.py | 99 ++++++++++++++++++++++++++--------- xarray/core/treenode.py | 4 ++ xarray/tests/test_datatree.py | 55 +++++++++++++++++-- 3 files changed, 127 insertions(+), 31 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 5fe1362c3c6..e453bafb333 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -4,7 +4,7 @@ import io import itertools import textwrap -from collections import ChainMap +from collections import ChainMap, defaultdict from collections.abc import ( Callable, Hashable, @@ -12,6 +12,7 @@ Iterator, Mapping, ) +from dataclasses import dataclass, field from html import escape from os import PathLike from typing import ( @@ -441,6 +442,17 @@ def map( # type: ignore[override] return Dataset(variables, attrs=attrs) +@dataclass +class _CoordWrapper: + value: CoercibleValue + + +@dataclass +class _DatasetArgs: + data_vars: dict[str, CoercibleValue] = field(default_factory=dict) + coords: dict[str, CoercibleValue] = field(default_factory=dict) + + class DataTree( NamedNode, DataTreeAggregations, @@ -1157,8 +1169,8 @@ def drop_nodes( @classmethod def from_dict( cls, - d: Mapping[str, Dataset | DataTree | None], - /, + data: Mapping[str, CoercibleValue | Dataset | DataTree | None] | None = None, + coords: Mapping[str, CoercibleValue] | None = None, name: str | None = None, ) -> Self: """ @@ -1166,8 +1178,9 @@ def from_dict( Parameters ---------- - d : dict-like - A mapping from path names to xarray.Dataset or DataTree objects. + data : dict-like, optional + A mapping from path names to DataTree or Dataset objects, or objects + coercible into a DataArray. Path names are to be given as unix-like path. If path names containing more than one part are given, new tree nodes will be @@ -1175,6 +1188,8 @@ def from_dict( To assign data to the root node of the tree use "", ".", "/" or "./" as the path. + coords : dict-like, optional + A mapping from path names to objects coercible into a DataArray. name : Hashable | None, optional Name for the root node of the tree. Default is None. @@ -1186,19 +1201,53 @@ def from_dict( ----- If your dictionary is nested you will need to flatten it before using this method. """ - # Find any values corresponding to the root - d_cast = dict(d) - root_data = None - for key in ("", ".", "/", "./"): - if key in d_cast: - if root_data is not None: + if data is None: + data = {} + + if coords is None: + coords = {} + + # Canonicalize and unify paths between `data` and `coords` + nodes: dict[ + NodePath, _CoordWrapper | CoercibleValue | Dataset | DataTree | None + ] = {} + for key, value in data.items(): + path = NodePath(key).absolute() + if path in nodes: + raise ValueError( + f"multiple entries found corresponding to node {str(path)!r}" + ) + nodes[path] = value + for key, value in coords.items(): + path = NodePath(key).absolute() + if path in nodes: + raise ValueError( + f"multiple entries found corresponding to node {str(path)!r}" + ) + nodes[path] = _CoordWrapper(value) + + # Merge nodes corresponding to DataArrays into Datasets + dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs) + for path in list(nodes): + node = nodes[path] + if node is not None and not isinstance(node, Dataset | DataTree): + if path.parent == path: + raise ValueError("cannot set DataArray value at root") + if path.parent in nodes: raise ValueError( - "multiple entries found corresponding to the root node" + f"cannot set DataArray value at {str(path)!r} when parent node at {str(path.parent)!r} is also set" ) - root_data = d_cast.pop(key) + del nodes[path] + if isinstance(node, _CoordWrapper): + dataset_args[path.parent].coords[path.name] = node.value + else: + dataset_args[path.parent].data_vars[path.name] = node + for path, args in dataset_args.items(): + nodes[path] = Dataset(args.data_vars, args.coords) # Create the root node - if isinstance(root_data, DataTree): + root_data = nodes.pop(NodePath("/"), None) + if isinstance(root_data, cls): obj = root_data.copy() obj.name = name elif root_data is None or isinstance(root_data, Dataset): @@ -1211,19 +1260,19 @@ def from_dict( def depth(item) -> int: pathstr, _ = item - return len(NodePath(pathstr).parts) + return len(pathstr.parts) - if d_cast: - # Populate tree with children determined from data_objects mapping + if nodes: + # Populate tree with children # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) - for path, data in sorted(d_cast.items(), key=depth): + for path, node in sorted(nodes.items(), key=depth): # Create and set new node - if isinstance(data, DataTree): - new_node = data.copy() - elif isinstance(data, Dataset) or data is None: - new_node = cls(dataset=data) + if isinstance(node, DataTree): + new_node = node.copy() + elif isinstance(node, Dataset) or node is None: + new_node = cls(dataset=node) else: - raise TypeError(f"invalid values: {data}") + raise TypeError(f"invalid values: {node}") obj._set_item( path, new_node, @@ -1231,9 +1280,7 @@ def depth(item) -> int: new_nodes_along_path=True, ) - # TODO: figure out why mypy is raising an error here, likely something - # to do with the return type of Dataset.copy() - return obj # type: ignore[return-value] + return obj def to_dict(self, relative: bool = False) -> dict[str, Dataset]: """ diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index d15dc51ec33..58c0efafbdb 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -42,6 +42,10 @@ def __init__(self, *pathsegments): ) # TODO should we also forbid suffixes to avoid node names with dots in them? + def absolute(self) -> Self: + """Convert into an absolute path.""" + return type(self)("/", *self.parts) + class TreeNode: """ diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 7c114d31104..e9637af6148 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -905,10 +905,49 @@ def test_insertion_order(self) -> None: # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] - def test_array_values(self) -> None: + def test_array_values_dataarray(self) -> None: + expected = DataTree(dataset=Dataset({"a": 1})) + actual = DataTree.from_dict({"a": DataArray(1)}) + assert_identical(actual, expected) + + def test_array_values_scalars(self) -> None: + expected = DataTree( + dataset=Dataset({"a": 1}), + children={"b": DataTree(Dataset({"c": 2, "d": 3}))}, + ) + actual = DataTree.from_dict({"a": 1, "b/c": 2, "b/d": 3}) + assert_identical(actual, expected) + + def test_array_values_deep(self) -> None: + expected = DataTree( + children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} + ) + actual = DataTree.from_dict(data={"a/b/c": 1}) + assert_identical(actual, expected) + + def test_array_values_data_and_coords(self) -> None: + expected = DataTree(dataset=Dataset({"a": 1}, coords={"b": 2})) + actual = DataTree.from_dict(data={"a": 1}, coords={"b": 2}) + assert_identical(actual, expected) + + def test_array_values_new_name(self) -> None: + expected = DataTree(dataset=Dataset({"foo": 1})) data = {"foo": xr.DataArray(1, name="bar")} - with pytest.raises(TypeError): - DataTree.from_dict(data) # type: ignore[arg-type] + actual = DataTree.from_dict(data) + assert_identical(actual, expected) + + def test_array_values_at_root(self) -> None: + with pytest.raises(ValueError, match="cannot set DataArray value at root"): + DataTree.from_dict({"/": 1}) + + def test_array_values_parent_node_also_set(self) -> None: + with pytest.raises( + ValueError, + match=re.escape( + r"cannot set DataArray value at '/a' when parent node at '/' is also set" + ), + ): + DataTree.from_dict({"/": Dataset(), "/a": 1}) def test_relative_paths(self) -> None: tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None}) @@ -937,10 +976,16 @@ def test_root_keys(self): actual = DataTree.from_dict({"./": ds}) assert_identical(actual, expected) + def test_multiple_entries(self): + with pytest.raises( + ValueError, match="multiple entries found corresponding to node '/'" + ): + DataTree.from_dict({"": None, ".": None}) + with pytest.raises( - ValueError, match="multiple entries found corresponding to the root node" + ValueError, match="multiple entries found corresponding to node '/a'" ): - DataTree.from_dict({"": ds, "/": ds}) + DataTree.from_dict({"a": None, "/a": None}) def test_name(self): tree = DataTree.from_dict({"/": None}, name="foo") From 800205c44dd716ed290f916091282ecde48116d0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 19 Aug 2025 21:57:32 -0700 Subject: [PATCH 2/9] Add docs --- doc/whats-new.rst | 4 +++ xarray/core/datatree.py | 73 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a46dba9f15a..ee334fb8108 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -20,6 +20,10 @@ New Features - ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr`. By `Stephan Hoyer `_. +- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` values, + and has a ``coords`` argument for specifying coordinates as ``DataArray`` + objects (:pull:`10658`). + By `Stephan Hoyer `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e453bafb333..fdd8dc6ce60 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1179,12 +1179,14 @@ def from_dict( Parameters ---------- data : dict-like, optional - A mapping from path names to DataTree or Dataset objects, or objects - coercible into a DataArray. + A mapping from path names to ``None`` (indicating an empty node), + ``DataTree`` or ``Dataset`` objects, or objects coercible into a + ``DataArray``. - Path names are to be given as unix-like path. If path names - containing more than one part are given, new tree nodes will be - constructed as necessary. + Path names should be given as unix-like paths, either absolute + (/path/to/item) or relative to the root node (path/to/item). If path + names containing more than one part are given, new tree nodes will + be constructed automatically as necessary. To assign data to the root node of the tree use "", ".", "/" or "./" as the path. @@ -1197,9 +1199,67 @@ def from_dict( ------- DataTree + See also + -------- + Dataset + Notes ----- If your dictionary is nested you will need to flatten it before using this method. + + Examples + -------- + + Construct a tree from a dict of Dataset objects: + + >>> dt = DataTree.from_dict( + ... { + ... "/": Dataset(coords={"time": [1, 2, 3]}), + ... "/ocean": Dataset( + ... { + ... "temperature": ("time", [4, 5, 6]), + ... "salinity": ("time", [7, 8, 9]), + ... } + ... ), + ... "/atmosphere": Dataset( + ... { + ... "temperature": ("time", [2, 3, 4]), + ... "humidity": ("time", [4, 5, 6]), + ... } + ... ), + ... } + ... ) + >>> dt + + Group: / + │ Dimensions: (time: 3) + │ Coordinates: + │ * time (time) int64 24B 1 2 3 + ├── Group: /ocean + │ Dimensions: (time: 3) + │ Data variables: + │ temperature (time) int64 24B 4 5 6 + │ salinity (time) int64 24B 7 8 9 + └── Group: /atmosphere + Dimensions: (time: 3) + Data variables: + temperature (time) int64 24B 2 3 4 + humidity (time) int64 24B 4 5 6 + + Or equivalently, use a dict of values that can be converted into + `DataArray` objects, with syntax similar to the Dataset constructor: + + >>> dt2 = DataTree.from_dict( + ... data={ + ... "/ocean/temperature": ("time", [4, 5, 6]), + ... "/ocean/salinity": ("time", [7, 8, 9]), + ... "/atmosphere/temperature": ("time", [2, 3, 4]), + ... "/atmosphere/humidity": ("time", [3, 4, 5]), + ... }, + ... coords={"/time": [1, 2, 3]}, + ... ) + >>> assert dt.identical(dt2) + """ if data is None: data = {} @@ -1235,7 +1295,8 @@ def from_dict( raise ValueError("cannot set DataArray value at root") if path.parent in nodes: raise ValueError( - f"cannot set DataArray value at {str(path)!r} when parent node at {str(path.parent)!r} is also set" + f"cannot set DataArray value at {str(path)!r} when " + f"parent node at {str(path.parent)!r} is also set" ) del nodes[path] if isinstance(node, _CoordWrapper): From 1baf00289a72da3c7ac43f39460027598be1e867 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 20 Aug 2025 11:59:04 -0700 Subject: [PATCH 3/9] Add support for flattening in from_dict --- doc/whats-new.rst | 6 ++-- xarray/core/datatree.py | 52 +++++++++++++++++++++++------------ xarray/core/types.py | 4 +++ xarray/core/utils.py | 21 +++++++++++++- xarray/tests/test_datatree.py | 21 ++++++++++++++ xarray/tests/test_utils.py | 8 ++++++ 6 files changed, 91 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ee334fb8108..fd6dbc3b72b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -20,9 +20,9 @@ New Features - ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr`. By `Stephan Hoyer `_. -- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` values, - and has a ``coords`` argument for specifying coordinates as ``DataArray`` - objects (:pull:`10658`). +- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` and nested + dictionary values, and has a ``coords`` argument for specifying coordinates as + ``DataArray`` objects (:pull:`10658`). By `Stephan Hoyer `_. Breaking changes diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index fdd8dc6ce60..1a6838e1815 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -86,6 +86,7 @@ DtCompatible, ErrorOptions, ErrorOptionsWithWarn, + NestedDict, NetcdfWriteModes, T_ChunkDimFreq, T_ChunksFreq, @@ -1169,8 +1170,16 @@ def drop_nodes( @classmethod def from_dict( cls, - data: Mapping[str, CoercibleValue | Dataset | DataTree | None] | None = None, - coords: Mapping[str, CoercibleValue] | None = None, + data: Mapping[ + str, + CoercibleValue + | Dataset + | DataTree + | None + | NestedDict[CoercibleValue | Dataset | DataTree | None], + ] + | None = None, + coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None, name: str | None = None, ) -> Self: """ @@ -1180,18 +1189,21 @@ def from_dict( ---------- data : dict-like, optional A mapping from path names to ``None`` (indicating an empty node), - ``DataTree`` or ``Dataset`` objects, or objects coercible into a - ``DataArray``. + ``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or + a nested dictionary of any of the above types. Path names should be given as unix-like paths, either absolute (/path/to/item) or relative to the root node (path/to/item). If path names containing more than one part are given, new tree nodes will be constructed automatically as necessary. + Nested dictionaries are automatically flattened. + To assign data to the root node of the tree use "", ".", "/" or "./" as the path. coords : dict-like, optional - A mapping from path names to objects coercible into a DataArray. + A mapping from path names to objects coercible into a DataArray, or + nested dictionaries of coercible objects. name : Hashable | None, optional Name for the root node of the tree. Default is None. @@ -1203,10 +1215,6 @@ def from_dict( -------- Dataset - Notes - ----- - If your dictionary is nested you will need to flatten it before using this method. - Examples -------- @@ -1260,6 +1268,19 @@ def from_dict( ... ) >>> assert dt.identical(dt2) + Nested dictionaries are automatically flattened: + + >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}) + + Group: / + └── Group: /a + └── Group: /a/b + └── Group: /a/b/c + Dimensions: () + Data variables: + x int64 8B 1 + y int64 8B 2 + """ if data is None: data = {} @@ -1268,23 +1289,20 @@ def from_dict( coords = {} # Canonicalize and unify paths between `data` and `coords` + flat_data_and_coords = itertools.chain( + utils.flat_items(data), + ((k, _CoordWrapper(v)) for k, v in utils.flat_items(coords)), + ) nodes: dict[ NodePath, _CoordWrapper | CoercibleValue | Dataset | DataTree | None ] = {} - for key, value in data.items(): + for key, value in flat_data_and_coords: path = NodePath(key).absolute() if path in nodes: raise ValueError( f"multiple entries found corresponding to node {str(path)!r}" ) nodes[path] = value - for key, value in coords.items(): - path = NodePath(key).absolute() - if path in nodes: - raise ValueError( - f"multiple entries found corresponding to node {str(path)!r}" - ) - nodes[path] = _CoordWrapper(value) # Merge nodes corresponding to DataArrays into Datasets dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs) diff --git a/xarray/core/types.py b/xarray/core/types.py index 736a11f5f17..26fabaf0f5f 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -304,6 +304,10 @@ def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ... def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ... +_T = TypeVar("_T") +NestedDict = dict[str, "NestedDict[_T] | _T"] + + AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e490fc05c2f..a33e77fd8f4 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -97,7 +97,7 @@ ) if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn + from xarray.core.types import Dims, ErrorOptionsWithWarn, NestedDict K = TypeVar("K") V = TypeVar("V") @@ -318,6 +318,25 @@ def remove_incompatible_items( del first_dict[k] +def flat_items( + nested: Mapping[str, NestedDict[T] | T], + prefix: str | None = None, + separator: str = "/", +) -> Iterable[tuple[str, T]]: + """Yields flat items from a nested dictionary of dicts. + + Notes: + - Only dict subclasses are flattened. + - Duplicate items are not removed. These should be checked separately. + """ + for key, value in nested.items(): + key = prefix + separator + key if prefix is not None else key + if isinstance(value, dict): + yield from flat_items(value, key, separator) + else: + yield key, value + + def is_full_slice(value: Any) -> bool: return isinstance(value, slice) and value == slice(None) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index e9637af6148..cbac7b3ee48 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -925,11 +925,32 @@ def test_array_values_deep(self) -> None: actual = DataTree.from_dict(data={"a/b/c": 1}) assert_identical(actual, expected) + def test_nested_array_values(self) -> None: + expected = DataTree( + children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} + ) + actual = DataTree.from_dict({"a": {"b": {"c": 1}}}) + assert_identical(actual, expected) + + def test_nested_array_values_duplicates(self) -> None: + with pytest.raises( + ValueError, + match=re.escape("multiple entries found corresponding to node '/a/b'"), + ): + DataTree.from_dict({"a": {"b": 1}, "a/b": 2}) + def test_array_values_data_and_coords(self) -> None: expected = DataTree(dataset=Dataset({"a": 1}, coords={"b": 2})) actual = DataTree.from_dict(data={"a": 1}, coords={"b": 2}) assert_identical(actual, expected) + def test_data_and_coords_conflicting(self) -> None: + with pytest.raises( + ValueError, + match=re.escape("multiple entries found corresponding to node '/a'"), + ): + DataTree.from_dict(data={"a": 1}, coords={"a": 2}) + def test_array_values_new_name(self) -> None: expected = DataTree(dataset=Dataset({"foo": 1})) data = {"foo": xr.DataArray(1, name="bar")} diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 0e6bbf29a45..e37e77909c1 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -11,6 +11,7 @@ from xarray.core.utils import ( attempt_import, either_dict_or_kwargs, + flat_items, infix_dims, iterate_nested, ) @@ -151,6 +152,13 @@ def test_filtered(self): assert dict(x) == {"a": 1} +def test_flat_items() -> None: + mapping = {"x": {"y": 1, "z": 2}, "x/y": 3} + actual = list(flat_items(mapping)) + expected = [("x/y", 1), ("x/z", 2), ("x/y", 3)] + assert actual == expected + + def test_repr_object(): obj = utils.ReprObject("foo") assert repr(obj) == "foo" From 914a20b15b39590f9d3a8413bf49a5c94e7dd95f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 20 Aug 2025 12:08:43 -0700 Subject: [PATCH 4/9] Fix doctest --- xarray/core/datatree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1a6838e1815..1ec2ebee1a1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1232,7 +1232,7 @@ def from_dict( ... "/atmosphere": Dataset( ... { ... "temperature": ("time", [2, 3, 4]), - ... "humidity": ("time", [4, 5, 6]), + ... "humidity": ("time", [3, 4, 5]), ... } ... ), ... } @@ -1252,7 +1252,7 @@ def from_dict( Dimensions: (time: 3) Data variables: temperature (time) int64 24B 2 3 4 - humidity (time) int64 24B 4 5 6 + humidity (time) int64 24B 3 4 5 Or equivalently, use a dict of values that can be converted into `DataArray` objects, with syntax similar to the Dataset constructor: From 44f1b83a7ca2f51856a955395207d2be9878644d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 20 Aug 2025 12:48:09 -0700 Subject: [PATCH 5/9] Add note about DataTree.from_dict vs DataArray.from_dict --- xarray/core/datatree.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1ec2ebee1a1..f60d3306681 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1215,6 +1215,13 @@ def from_dict( -------- Dataset + Notes + ----- + ``DataTree.from_dict`` serves a conceptually different purpose from + ``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a + hierarchy of Xarray objects into a DataTree, rather than converting pure + Python data structures. + Examples -------- From 28fe9d01cd3b20f393d4f4d02c5e8acff785accc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 4 Sep 2025 10:24:05 -0700 Subject: [PATCH 6/9] Fix whats new --- doc/whats-new.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index deaac7b8631..aaba60faf29 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,13 +13,6 @@ v2025.09.1 (unreleased) New Features ~~~~~~~~~~~~ -- Added :py:func:`load_datatree` for loading ``DataTree`` objects into memory - from disk. It has the same relationship to :py:func:`open_datatree`, as - :py:func:`load_dataset` has to :py:func:`open_dataset`. - By `Stephan Hoyer `_. -- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and - :py:meth:`DataTree.to_zarr`. - By `Stephan Hoyer `_. - :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` and nested dictionary values, and has a ``coords`` argument for specifying coordinates as ``DataArray`` objects (:pull:`10658`). From 0261445c4ecb589c9c8f5657830ffeded3b98365 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 4 Sep 2025 10:59:21 -0700 Subject: [PATCH 7/9] Require nested=True for processing nested items --- xarray/core/datatree.py | 57 ++++++++++++++++++++++++++++++----- xarray/tests/test_datatree.py | 26 ++++++++++++++-- 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f60d3306681..1832a7a9d3e 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1167,6 +1167,36 @@ def drop_nodes( result._replace_node(children=children_to_keep) return result + @overload + @classmethod + def from_dict( + cls, + data: Mapping[str, CoercibleValue | Dataset | DataTree | None] | None = ..., + coords: Mapping[str, CoercibleValue] | None = ..., + *, + name: str | None = ..., + nested: Literal[False] = ..., + ) -> Self: ... + + @overload + @classmethod + def from_dict( + cls, + data: Mapping[ + str, + CoercibleValue + | Dataset + | DataTree + | None + | NestedDict[CoercibleValue | Dataset | DataTree | None], + ] + | None = ..., + coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = ..., + *, + name: str | None = ..., + nested: Literal[True] = ..., + ) -> Self: ... + @classmethod def from_dict( cls, @@ -1180,7 +1210,9 @@ def from_dict( ] | None = None, coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None, + *, name: str | None = None, + nested: bool = False, ) -> Self: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1197,8 +1229,6 @@ def from_dict( names containing more than one part are given, new tree nodes will be constructed automatically as necessary. - Nested dictionaries are automatically flattened. - To assign data to the root node of the tree use "", ".", "/" or "./" as the path. coords : dict-like, optional @@ -1206,6 +1236,9 @@ def from_dict( nested dictionaries of coercible objects. name : Hashable | None, optional Name for the root node of the tree. Default is None. + nested : bool, optional + If true, nested dictionaries in ``data`` and ``coords`` are + automatically flattened. Returns ------- @@ -1275,9 +1308,9 @@ def from_dict( ... ) >>> assert dt.identical(dt2) - Nested dictionaries are automatically flattened: + Nested dictionaries are automatically flattened if ``nested=True``: - >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}) + >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True) Group: / └── Group: /a @@ -1295,10 +1328,13 @@ def from_dict( if coords is None: coords = {} + data_items = utils.flat_items(data) if nested else data.items() + coords_items = utils.flat_items(coords) if nested else coords.items() + # Canonicalize and unify paths between `data` and `coords` flat_data_and_coords = itertools.chain( - utils.flat_items(data), - ((k, _CoordWrapper(v)) for k, v in utils.flat_items(coords)), + data_items, + ((k, _CoordWrapper(v)) for k, v in coords_items), ) nodes: dict[ NodePath, _CoordWrapper | CoercibleValue | Dataset | DataTree | None @@ -1329,7 +1365,14 @@ def from_dict( else: dataset_args[path.parent].data_vars[path.name] = node for path, args in dataset_args.items(): - nodes[path] = Dataset(args.data_vars, args.coords) + try: + nodes[path] = Dataset(args.data_vars, args.coords) + except (ValueError, TypeError) as e: + raise type(e)( + "failed to construct xarray.Dataset for DataTree node at " + f"{str(path)!r} with data_vars={args.data_vars} and " + f"coords={args.coords}" + ) from e # Create the root node root_data = nodes.pop(NodePath("/"), None) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index a9c58257931..f04f1aa83ba 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -918,7 +918,17 @@ def test_array_values_scalars(self) -> None: actual = DataTree.from_dict({"a": 1, "b/c": 2, "b/d": 3}) assert_identical(actual, expected) - def test_array_values_deep(self) -> None: + def test_invalid_values(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + r"failed to construct xarray.Dataset for DataTree node at '/' " + r"with data_vars={'a': set()} and coords={}" + ), + ): + DataTree.from_dict({"a": set()}) + + def test_array_values_nested_key(self) -> None: expected = DataTree( children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} ) @@ -929,15 +939,25 @@ def test_nested_array_values(self) -> None: expected = DataTree( children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} ) - actual = DataTree.from_dict({"a": {"b": {"c": 1}}}) + actual = DataTree.from_dict({"a": {"b": {"c": 1}}}, nested=True) assert_identical(actual, expected) + def test_nested_array_values_without_nested_kwarg(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + r"failed to construct xarray.Dataset for DataTree node at '/' " + r"with data_vars={'a': {'b': {'c': 1}}} and coords={}" + ), + ): + DataTree.from_dict({"a": {"b": {"c": 1}}}) + def test_nested_array_values_duplicates(self) -> None: with pytest.raises( ValueError, match=re.escape("multiple entries found corresponding to node '/a/b'"), ): - DataTree.from_dict({"a": {"b": 1}, "a/b": 2}) + DataTree.from_dict({"a": {"b": 1}, "a/b": 2}, nested=True) def test_array_values_data_and_coords(self) -> None: expected = DataTree(dataset=Dataset({"a": 1}, coords={"b": 2})) From 9fa34cf694ad22b05f3df4526b610391ea32ae4a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 4 Sep 2025 12:01:04 -0700 Subject: [PATCH 8/9] Better error message --- xarray/core/datatree.py | 16 ++++++++++++++-- xarray/tests/test_datatree.py | 5 +++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1832a7a9d3e..c15898e22fa 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1328,8 +1328,20 @@ def from_dict( if coords is None: coords = {} - data_items = utils.flat_items(data) if nested else data.items() - coords_items = utils.flat_items(coords) if nested else coords.items() + if nested: + data_items = utils.flat_items(data) + coords_items = utils.flat_items(coords) + else: + data_items = data.items() + coords_items = coords.items() + for arg_name, items in [("data", data_items), ("coords", coords_items)]: + for key, value in items: + if isinstance(value, dict): + raise TypeError( + f"{arg_name} contains a dict value at {key=}, " + "which is not a valid argument to " + f"DataTree.from_dict() with nested=False: {value}" + ) # Canonicalize and unify paths between `data` and `coords` flat_data_and_coords = itertools.chain( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f04f1aa83ba..a368c56dee9 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -946,8 +946,9 @@ def test_nested_array_values_without_nested_kwarg(self) -> None: with pytest.raises( TypeError, match=re.escape( - r"failed to construct xarray.Dataset for DataTree node at '/' " - r"with data_vars={'a': {'b': {'c': 1}}} and coords={}" + r"data contains a dict value at key='a', which is not a valid " + r"argument to DataTree.from_dict() with nested=False: " + r"{'b': {'c': 1}}" ), ): DataTree.from_dict({"a": {"b": {"c": 1}}}) From 8de05376d9be191432b0c75bb4617df7d4c9b92e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 18 Sep 2025 10:13:38 -0700 Subject: [PATCH 9/9] improve typing --- xarray/core/datatree.py | 41 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 14295aa2955..056e6442a21 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -22,6 +22,7 @@ Literal, NoReturn, ParamSpec, + TypeAlias, TypeVar, Union, overload, @@ -443,6 +444,9 @@ def map( # type: ignore[override] return Dataset(variables, attrs=attrs) +FromDictDataValue: TypeAlias = "CoercibleValue | Dataset | DataTree | None" + + @dataclass class _CoordWrapper: value: CoercibleValue @@ -1171,7 +1175,7 @@ def drop_nodes( @classmethod def from_dict( cls, - data: Mapping[str, CoercibleValue | Dataset | DataTree | None] | None = ..., + data: Mapping[str, FromDictDataValue] | None = ..., coords: Mapping[str, CoercibleValue] | None = ..., *, name: str | None = ..., @@ -1182,15 +1186,9 @@ def from_dict( @classmethod def from_dict( cls, - data: Mapping[ - str, - CoercibleValue - | Dataset - | DataTree - | None - | NestedDict[CoercibleValue | Dataset | DataTree | None], - ] - | None = ..., + data: ( + Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None + ) = ..., coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = ..., *, name: str | None = ..., @@ -1200,15 +1198,9 @@ def from_dict( @classmethod def from_dict( cls, - data: Mapping[ - str, - CoercibleValue - | Dataset - | DataTree - | None - | NestedDict[CoercibleValue | Dataset | DataTree | None], - ] - | None = None, + data: ( + Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None + ) = None, coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None, *, name: str | None = None, @@ -1348,9 +1340,7 @@ def from_dict( data_items, ((k, _CoordWrapper(v)) for k, v in coords_items), ) - nodes: dict[ - NodePath, _CoordWrapper | CoercibleValue | Dataset | DataTree | None - ] = {} + nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {} for key, value in flat_data_and_coords: path = NodePath(key).absolute() if path in nodes: @@ -1389,6 +1379,7 @@ def from_dict( # Create the root node root_data = nodes.pop(NodePath("/"), None) if isinstance(root_data, cls): + # use cls so type-checkers understand this method returns Self obj = root_data.copy() obj.name = name elif root_data is None or isinstance(root_data, Dataset): @@ -1399,9 +1390,9 @@ def from_dict( f"or DataTree, got {type(root_data)}" ) - def depth(item) -> int: - pathstr, _ = item - return len(pathstr.parts) + def depth(item: tuple[NodePath, object]) -> int: + node_path, _ = item + return len(node_path.parts) if nodes: # Populate tree with children