diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1aef6b9db3d..db86db5157f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,9 @@ v2025.09.1 (unreleased) New Features ~~~~~~~~~~~~ +- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` and nested + dictionary values, and has a ``coords`` argument for specifying coordinates as + ``DataArray`` objects (:pull:`10658`). - ``engine='netcdf4'`` now supports reading and writing in-memory netCDF files. All of Xarray's netCDF backends now support in-memory reads and writes (:pull:`10624`). diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 99441b1a8d4..056e6442a21 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -4,7 +4,7 @@ import io import itertools import textwrap -from collections import ChainMap +from collections import ChainMap, defaultdict from collections.abc import ( Callable, Hashable, @@ -12,6 +12,7 @@ Iterator, Mapping, ) +from dataclasses import dataclass, field from html import escape from os import PathLike from typing import ( @@ -21,6 +22,7 @@ Literal, NoReturn, ParamSpec, + TypeAlias, TypeVar, Union, overload, @@ -85,6 +87,7 @@ DtCompatible, ErrorOptions, ErrorOptionsWithWarn, + NestedDict, NetcdfWriteModes, T_ChunkDimFreq, T_ChunksFreq, @@ -441,6 +444,20 @@ def map( # type: ignore[override] return Dataset(variables, attrs=attrs) +FromDictDataValue: TypeAlias = "CoercibleValue | Dataset | DataTree | None" + + +@dataclass +class _CoordWrapper: + value: CoercibleValue + + +@dataclass +class _DatasetArgs: + data_vars: dict[str, CoercibleValue] = field(default_factory=dict) + coords: dict[str, CoercibleValue] = field(default_factory=dict) + + class DataTree( NamedNode, DataTreeAggregations, @@ -1154,51 +1171,215 @@ def drop_nodes( result._replace_node(children=children_to_keep) return result + @overload + @classmethod + def from_dict( + cls, + data: Mapping[str, FromDictDataValue] | None = ..., + coords: Mapping[str, CoercibleValue] | None = ..., + *, + name: str | None = ..., + nested: Literal[False] = ..., + ) -> Self: ... + + @overload + @classmethod + def from_dict( + cls, + data: ( + Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None + ) = ..., + coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = ..., + *, + name: str | None = ..., + nested: Literal[True] = ..., + ) -> Self: ... + @classmethod def from_dict( cls, - d: Mapping[str, Dataset | DataTree | None], - /, + data: ( + Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None + ) = None, + coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None, + *, name: str | None = None, + nested: bool = False, ) -> Self: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. Parameters ---------- - d : dict-like - A mapping from path names to xarray.Dataset or DataTree objects. + data : dict-like, optional + A mapping from path names to ``None`` (indicating an empty node), + ``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or + a nested dictionary of any of the above types. - Path names are to be given as unix-like path. If path names - containing more than one part are given, new tree nodes will be - constructed as necessary. + Path names should be given as unix-like paths, either absolute + (/path/to/item) or relative to the root node (path/to/item). If path + names containing more than one part are given, new tree nodes will + be constructed automatically as necessary. To assign data to the root node of the tree use "", ".", "/" or "./" as the path. + coords : dict-like, optional + A mapping from path names to objects coercible into a DataArray, or + nested dictionaries of coercible objects. name : Hashable | None, optional Name for the root node of the tree. Default is None. + nested : bool, optional + If true, nested dictionaries in ``data`` and ``coords`` are + automatically flattened. Returns ------- DataTree + See also + -------- + Dataset + Notes ----- - If your dictionary is nested you will need to flatten it before using this method. - """ - # Find any values corresponding to the root - d_cast = dict(d) - root_data = None - for key in ("", ".", "/", "./"): - if key in d_cast: - if root_data is not None: + ``DataTree.from_dict`` serves a conceptually different purpose from + ``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a + hierarchy of Xarray objects into a DataTree, rather than converting pure + Python data structures. + + Examples + -------- + + Construct a tree from a dict of Dataset objects: + + >>> dt = DataTree.from_dict( + ... { + ... "/": Dataset(coords={"time": [1, 2, 3]}), + ... "/ocean": Dataset( + ... { + ... "temperature": ("time", [4, 5, 6]), + ... "salinity": ("time", [7, 8, 9]), + ... } + ... ), + ... "/atmosphere": Dataset( + ... { + ... "temperature": ("time", [2, 3, 4]), + ... "humidity": ("time", [3, 4, 5]), + ... } + ... ), + ... } + ... ) + >>> dt + + Group: / + │ Dimensions: (time: 3) + │ Coordinates: + │ * time (time) int64 24B 1 2 3 + ├── Group: /ocean + │ Dimensions: (time: 3) + │ Data variables: + │ temperature (time) int64 24B 4 5 6 + │ salinity (time) int64 24B 7 8 9 + └── Group: /atmosphere + Dimensions: (time: 3) + Data variables: + temperature (time) int64 24B 2 3 4 + humidity (time) int64 24B 3 4 5 + + Or equivalently, use a dict of values that can be converted into + `DataArray` objects, with syntax similar to the Dataset constructor: + + >>> dt2 = DataTree.from_dict( + ... data={ + ... "/ocean/temperature": ("time", [4, 5, 6]), + ... "/ocean/salinity": ("time", [7, 8, 9]), + ... "/atmosphere/temperature": ("time", [2, 3, 4]), + ... "/atmosphere/humidity": ("time", [3, 4, 5]), + ... }, + ... coords={"/time": [1, 2, 3]}, + ... ) + >>> assert dt.identical(dt2) + + Nested dictionaries are automatically flattened if ``nested=True``: + + >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True) + + Group: / + └── Group: /a + └── Group: /a/b + └── Group: /a/b/c + Dimensions: () + Data variables: + x int64 8B 1 + y int64 8B 2 + + """ + if data is None: + data = {} + + if coords is None: + coords = {} + + if nested: + data_items = utils.flat_items(data) + coords_items = utils.flat_items(coords) + else: + data_items = data.items() + coords_items = coords.items() + for arg_name, items in [("data", data_items), ("coords", coords_items)]: + for key, value in items: + if isinstance(value, dict): + raise TypeError( + f"{arg_name} contains a dict value at {key=}, " + "which is not a valid argument to " + f"DataTree.from_dict() with nested=False: {value}" + ) + + # Canonicalize and unify paths between `data` and `coords` + flat_data_and_coords = itertools.chain( + data_items, + ((k, _CoordWrapper(v)) for k, v in coords_items), + ) + nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {} + for key, value in flat_data_and_coords: + path = NodePath(key).absolute() + if path in nodes: + raise ValueError( + f"multiple entries found corresponding to node {str(path)!r}" + ) + nodes[path] = value + + # Merge nodes corresponding to DataArrays into Datasets + dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs) + for path in list(nodes): + node = nodes[path] + if node is not None and not isinstance(node, Dataset | DataTree): + if path.parent == path: + raise ValueError("cannot set DataArray value at root") + if path.parent in nodes: raise ValueError( - "multiple entries found corresponding to the root node" + f"cannot set DataArray value at {str(path)!r} when " + f"parent node at {str(path.parent)!r} is also set" ) - root_data = d_cast.pop(key) + del nodes[path] + if isinstance(node, _CoordWrapper): + dataset_args[path.parent].coords[path.name] = node.value + else: + dataset_args[path.parent].data_vars[path.name] = node + for path, args in dataset_args.items(): + try: + nodes[path] = Dataset(args.data_vars, args.coords) + except (ValueError, TypeError) as e: + raise type(e)( + "failed to construct xarray.Dataset for DataTree node at " + f"{str(path)!r} with data_vars={args.data_vars} and " + f"coords={args.coords}" + ) from e # Create the root node - if isinstance(root_data, DataTree): + root_data = nodes.pop(NodePath("/"), None) + if isinstance(root_data, cls): + # use cls so type-checkers understand this method returns Self obj = root_data.copy() obj.name = name elif root_data is None or isinstance(root_data, Dataset): @@ -1209,21 +1390,21 @@ def from_dict( f"or DataTree, got {type(root_data)}" ) - def depth(item) -> int: - pathstr, _ = item - return len(NodePath(pathstr).parts) + def depth(item: tuple[NodePath, object]) -> int: + node_path, _ = item + return len(node_path.parts) - if d_cast: - # Populate tree with children determined from data_objects mapping + if nodes: + # Populate tree with children # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) - for path, data in sorted(d_cast.items(), key=depth): + for path, node in sorted(nodes.items(), key=depth): # Create and set new node - if isinstance(data, DataTree): - new_node = data.copy() - elif isinstance(data, Dataset) or data is None: - new_node = cls(dataset=data) + if isinstance(node, DataTree): + new_node = node.copy() + elif isinstance(node, Dataset) or node is None: + new_node = cls(dataset=node) else: - raise TypeError(f"invalid values: {data}") + raise TypeError(f"invalid values: {node}") obj._set_item( path, new_node, @@ -1231,9 +1412,7 @@ def depth(item) -> int: new_nodes_along_path=True, ) - # TODO: figure out why mypy is raising an error here, likely something - # to do with the return type of Dataset.copy() - return obj # type: ignore[return-value] + return obj def to_dict(self, relative: bool = False) -> dict[str, Dataset]: """ diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index d15dc51ec33..58c0efafbdb 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -42,6 +42,10 @@ def __init__(self, *pathsegments): ) # TODO should we also forbid suffixes to avoid node names with dots in them? + def absolute(self) -> Self: + """Convert into an absolute path.""" + return type(self)("/", *self.parts) + class TreeNode: """ diff --git a/xarray/core/types.py b/xarray/core/types.py index a0d62d30c9f..69cee210798 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -304,6 +304,10 @@ def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ... def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ... +_T = TypeVar("_T") +NestedDict = dict[str, "NestedDict[_T] | _T"] + + AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d44aff9ff36..ec4edf255f6 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -97,7 +97,7 @@ ) if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn + from xarray.core.types import Dims, ErrorOptionsWithWarn, NestedDict K = TypeVar("K") V = TypeVar("V") @@ -335,6 +335,25 @@ def remove_incompatible_items( del first_dict[k] +def flat_items( + nested: Mapping[str, NestedDict[T] | T], + prefix: str | None = None, + separator: str = "/", +) -> Iterable[tuple[str, T]]: + """Yields flat items from a nested dictionary of dicts. + + Notes: + - Only dict subclasses are flattened. + - Duplicate items are not removed. These should be checked separately. + """ + for key, value in nested.items(): + key = prefix + separator + key if prefix is not None else key + if isinstance(value, dict): + yield from flat_items(value, key, separator) + else: + yield key, value + + def is_full_slice(value: Any) -> bool: return isinstance(value, slice) and value == slice(None) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index b54fd3cb959..a368c56dee9 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -905,10 +905,91 @@ def test_insertion_order(self) -> None: # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] - def test_array_values(self) -> None: + def test_array_values_dataarray(self) -> None: + expected = DataTree(dataset=Dataset({"a": 1})) + actual = DataTree.from_dict({"a": DataArray(1)}) + assert_identical(actual, expected) + + def test_array_values_scalars(self) -> None: + expected = DataTree( + dataset=Dataset({"a": 1}), + children={"b": DataTree(Dataset({"c": 2, "d": 3}))}, + ) + actual = DataTree.from_dict({"a": 1, "b/c": 2, "b/d": 3}) + assert_identical(actual, expected) + + def test_invalid_values(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + r"failed to construct xarray.Dataset for DataTree node at '/' " + r"with data_vars={'a': set()} and coords={}" + ), + ): + DataTree.from_dict({"a": set()}) + + def test_array_values_nested_key(self) -> None: + expected = DataTree( + children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} + ) + actual = DataTree.from_dict(data={"a/b/c": 1}) + assert_identical(actual, expected) + + def test_nested_array_values(self) -> None: + expected = DataTree( + children={"a": DataTree(children={"b": DataTree(Dataset({"c": 1}))})} + ) + actual = DataTree.from_dict({"a": {"b": {"c": 1}}}, nested=True) + assert_identical(actual, expected) + + def test_nested_array_values_without_nested_kwarg(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + r"data contains a dict value at key='a', which is not a valid " + r"argument to DataTree.from_dict() with nested=False: " + r"{'b': {'c': 1}}" + ), + ): + DataTree.from_dict({"a": {"b": {"c": 1}}}) + + def test_nested_array_values_duplicates(self) -> None: + with pytest.raises( + ValueError, + match=re.escape("multiple entries found corresponding to node '/a/b'"), + ): + DataTree.from_dict({"a": {"b": 1}, "a/b": 2}, nested=True) + + def test_array_values_data_and_coords(self) -> None: + expected = DataTree(dataset=Dataset({"a": 1}, coords={"b": 2})) + actual = DataTree.from_dict(data={"a": 1}, coords={"b": 2}) + assert_identical(actual, expected) + + def test_data_and_coords_conflicting(self) -> None: + with pytest.raises( + ValueError, + match=re.escape("multiple entries found corresponding to node '/a'"), + ): + DataTree.from_dict(data={"a": 1}, coords={"a": 2}) + + def test_array_values_new_name(self) -> None: + expected = DataTree(dataset=Dataset({"foo": 1})) data = {"foo": xr.DataArray(1, name="bar")} - with pytest.raises(TypeError): - DataTree.from_dict(data) # type: ignore[arg-type] + actual = DataTree.from_dict(data) + assert_identical(actual, expected) + + def test_array_values_at_root(self) -> None: + with pytest.raises(ValueError, match="cannot set DataArray value at root"): + DataTree.from_dict({"/": 1}) + + def test_array_values_parent_node_also_set(self) -> None: + with pytest.raises( + ValueError, + match=re.escape( + r"cannot set DataArray value at '/a' when parent node at '/' is also set" + ), + ): + DataTree.from_dict({"/": Dataset(), "/a": 1}) def test_relative_paths(self) -> None: tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None}) @@ -937,10 +1018,16 @@ def test_root_keys(self): actual = DataTree.from_dict({"./": ds}) assert_identical(actual, expected) + def test_multiple_entries(self): + with pytest.raises( + ValueError, match="multiple entries found corresponding to node '/'" + ): + DataTree.from_dict({"": None, ".": None}) + with pytest.raises( - ValueError, match="multiple entries found corresponding to the root node" + ValueError, match="multiple entries found corresponding to node '/a'" ): - DataTree.from_dict({"": ds, "/": ds}) + DataTree.from_dict({"a": None, "/a": None}) def test_name(self): tree = DataTree.from_dict({"/": None}, name="foo") diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index f3333d188cc..1261df44f76 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -11,6 +11,7 @@ from xarray.core.utils import ( attempt_import, either_dict_or_kwargs, + flat_items, infix_dims, iterate_nested, ) @@ -151,6 +152,13 @@ def test_filtered(self): assert dict(x) == {"a": 1} +def test_flat_items() -> None: + mapping = {"x": {"y": 1, "z": 2}, "x/y": 3} + actual = list(flat_items(mapping)) + expected = [("x/y", 1), ("x/z", 2), ("x/y", 3)] + assert actual == expected + + def test_repr_object(): obj = utils.ReprObject("foo") assert repr(obj) == "foo"