Skip to content
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ v2025.09.1 (unreleased)
New Features
~~~~~~~~~~~~

- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` and nested
dictionary values, and has a ``coords`` argument for specifying coordinates as
``DataArray`` objects (:pull:`10658`).
- ``engine='netcdf4'`` now supports reading and writing in-memory netCDF files.
All of Xarray's netCDF backends now support in-memory reads and writes
(:pull:`10624`).
Expand Down
245 changes: 212 additions & 33 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import io
import itertools
import textwrap
from collections import ChainMap
from collections import ChainMap, defaultdict
from collections.abc import (
Callable,
Hashable,
Iterable,
Iterator,
Mapping,
)
from dataclasses import dataclass, field
from html import escape
from os import PathLike
from typing import (
Expand All @@ -21,6 +22,7 @@
Literal,
NoReturn,
ParamSpec,
TypeAlias,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -85,6 +87,7 @@
DtCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
NestedDict,
NetcdfWriteModes,
T_ChunkDimFreq,
T_ChunksFreq,
Expand Down Expand Up @@ -441,6 +444,20 @@ def map( # type: ignore[override]
return Dataset(variables, attrs=attrs)


FromDictDataValue: TypeAlias = "CoercibleValue | Dataset | DataTree | None"


@dataclass
class _CoordWrapper:
value: CoercibleValue


@dataclass
class _DatasetArgs:
data_vars: dict[str, CoercibleValue] = field(default_factory=dict)
coords: dict[str, CoercibleValue] = field(default_factory=dict)


class DataTree(
NamedNode,
DataTreeAggregations,
Expand Down Expand Up @@ -1154,51 +1171,215 @@ def drop_nodes(
result._replace_node(children=children_to_keep)
return result

@overload
@classmethod
def from_dict(
cls,
data: Mapping[str, FromDictDataValue] | None = ...,
coords: Mapping[str, CoercibleValue] | None = ...,
*,
name: str | None = ...,
nested: Literal[False] = ...,
) -> Self: ...

@overload
@classmethod
def from_dict(
cls,
data: (
Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None
) = ...,
coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = ...,
*,
name: str | None = ...,
nested: Literal[True] = ...,
) -> Self: ...

@classmethod
def from_dict(
cls,
d: Mapping[str, Dataset | DataTree | None],
/,
data: (
Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None
) = None,
coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None,
*,
name: str | None = None,
nested: bool = False,
) -> Self:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.

Parameters
----------
d : dict-like
A mapping from path names to xarray.Dataset or DataTree objects.
data : dict-like, optional
A mapping from path names to ``None`` (indicating an empty node),
``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or
a nested dictionary of any of the above types.

Path names are to be given as unix-like path. If path names
containing more than one part are given, new tree nodes will be
constructed as necessary.
Path names should be given as unix-like paths, either absolute
(/path/to/item) or relative to the root node (path/to/item). If path
names containing more than one part are given, new tree nodes will
be constructed automatically as necessary.

To assign data to the root node of the tree use "", ".", "/" or "./"
as the path.
coords : dict-like, optional
A mapping from path names to objects coercible into a DataArray, or
nested dictionaries of coercible objects.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.
nested : bool, optional
If true, nested dictionaries in ``data`` and ``coords`` are
automatically flattened.

Returns
-------
DataTree

See also
--------
Dataset

Notes
-----
If your dictionary is nested you will need to flatten it before using this method.
"""
# Find any values corresponding to the root
d_cast = dict(d)
root_data = None
for key in ("", ".", "/", "./"):
if key in d_cast:
if root_data is not None:
``DataTree.from_dict`` serves a conceptually different purpose from
``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a
hierarchy of Xarray objects into a DataTree, rather than converting pure
Python data structures.

Examples
--------

Construct a tree from a dict of Dataset objects:

>>> dt = DataTree.from_dict(
... {
... "/": Dataset(coords={"time": [1, 2, 3]}),
... "/ocean": Dataset(
... {
... "temperature": ("time", [4, 5, 6]),
... "salinity": ("time", [7, 8, 9]),
... }
... ),
... "/atmosphere": Dataset(
... {
... "temperature": ("time", [2, 3, 4]),
... "humidity": ("time", [3, 4, 5]),
... }
... ),
... }
... )
>>> dt
<xarray.DataTree>
Group: /
│ Dimensions: (time: 3)
│ Coordinates:
│ * time (time) int64 24B 1 2 3
├── Group: /ocean
│ Dimensions: (time: 3)
│ Data variables:
│ temperature (time) int64 24B 4 5 6
│ salinity (time) int64 24B 7 8 9
└── Group: /atmosphere
Dimensions: (time: 3)
Data variables:
temperature (time) int64 24B 2 3 4
humidity (time) int64 24B 3 4 5

Or equivalently, use a dict of values that can be converted into
`DataArray` objects, with syntax similar to the Dataset constructor:

>>> dt2 = DataTree.from_dict(
... data={
... "/ocean/temperature": ("time", [4, 5, 6]),
... "/ocean/salinity": ("time", [7, 8, 9]),
... "/atmosphere/temperature": ("time", [2, 3, 4]),
... "/atmosphere/humidity": ("time", [3, 4, 5]),
... },
... coords={"/time": [1, 2, 3]},
... )
>>> assert dt.identical(dt2)

Nested dictionaries are automatically flattened if ``nested=True``:

>>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True)
<xarray.DataTree>
Group: /
└── Group: /a
└── Group: /a/b
└── Group: /a/b/c
Dimensions: ()
Data variables:
x int64 8B 1
y int64 8B 2

"""
if data is None:
data = {}

if coords is None:
coords = {}

if nested:
data_items = utils.flat_items(data)
coords_items = utils.flat_items(coords)
else:
data_items = data.items()
coords_items = coords.items()
for arg_name, items in [("data", data_items), ("coords", coords_items)]:
for key, value in items:
if isinstance(value, dict):
raise TypeError(
f"{arg_name} contains a dict value at {key=}, "
"which is not a valid argument to "
f"DataTree.from_dict() with nested=False: {value}"
)

# Canonicalize and unify paths between `data` and `coords`
flat_data_and_coords = itertools.chain(
data_items,
((k, _CoordWrapper(v)) for k, v in coords_items),
)
nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {}
for key, value in flat_data_and_coords:
path = NodePath(key).absolute()
if path in nodes:
raise ValueError(
f"multiple entries found corresponding to node {str(path)!r}"
)
nodes[path] = value

# Merge nodes corresponding to DataArrays into Datasets
dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs)
for path in list(nodes):
node = nodes[path]
if node is not None and not isinstance(node, Dataset | DataTree):
if path.parent == path:
raise ValueError("cannot set DataArray value at root")
if path.parent in nodes:
raise ValueError(
"multiple entries found corresponding to the root node"
f"cannot set DataArray value at {str(path)!r} when "
f"parent node at {str(path.parent)!r} is also set"
)
root_data = d_cast.pop(key)
del nodes[path]
if isinstance(node, _CoordWrapper):
dataset_args[path.parent].coords[path.name] = node.value
else:
dataset_args[path.parent].data_vars[path.name] = node
for path, args in dataset_args.items():
try:
nodes[path] = Dataset(args.data_vars, args.coords)
except (ValueError, TypeError) as e:
raise type(e)(
"failed to construct xarray.Dataset for DataTree node at "
f"{str(path)!r} with data_vars={args.data_vars} and "
f"coords={args.coords}"
) from e

# Create the root node
if isinstance(root_data, DataTree):
root_data = nodes.pop(NodePath("/"), None)
if isinstance(root_data, cls):
Comment thread
shoyer marked this conversation as resolved.
# use cls so type-checkers understand this method returns Self
obj = root_data.copy()
obj.name = name
elif root_data is None or isinstance(root_data, Dataset):
Expand All @@ -1209,31 +1390,29 @@ def from_dict(
f"or DataTree, got {type(root_data)}"
)

def depth(item) -> int:
pathstr, _ = item
return len(NodePath(pathstr).parts)
def depth(item: tuple[NodePath, object]) -> int:
node_path, _ = item
return len(node_path.parts)

if d_cast:
# Populate tree with children determined from data_objects mapping
if nodes:
# Populate tree with children
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
for path, node in sorted(nodes.items(), key=depth):
# Create and set new node
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(dataset=data)
if isinstance(node, DataTree):
new_node = node.copy()
elif isinstance(node, Dataset) or node is None:
new_node = cls(dataset=node)
else:
raise TypeError(f"invalid values: {data}")
raise TypeError(f"invalid values: {node}")
obj._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)

# TODO: figure out why mypy is raising an error here, likely something
# to do with the return type of Dataset.copy()
return obj # type: ignore[return-value]
return obj

def to_dict(self, relative: bool = False) -> dict[str, Dataset]:
"""
Expand Down
4 changes: 4 additions & 0 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, *pathsegments):
)
# TODO should we also forbid suffixes to avoid node names with dots in them?

def absolute(self) -> Self:
"""Convert into an absolute path."""
return type(self)("/", *self.parts)


class TreeNode:
"""
Expand Down
4 changes: 4 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...
def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...


_T = TypeVar("_T")
NestedDict = dict[str, "NestedDict[_T] | _T"]
Comment thread
shoyer marked this conversation as resolved.


AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)


Expand Down
21 changes: 20 additions & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
)

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn
from xarray.core.types import Dims, ErrorOptionsWithWarn, NestedDict

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -335,6 +335,25 @@ def remove_incompatible_items(
del first_dict[k]


def flat_items(
nested: Mapping[str, NestedDict[T] | T],
prefix: str | None = None,
separator: str = "/",
) -> Iterable[tuple[str, T]]:
"""Yields flat items from a nested dictionary of dicts.

Notes:
- Only dict subclasses are flattened.
- Duplicate items are not removed. These should be checked separately.
"""
for key, value in nested.items():
key = prefix + separator + key if prefix is not None else key
if isinstance(value, dict):
yield from flat_items(value, key, separator)
else:
yield key, value


def is_full_slice(value: Any) -> bool:
return isinstance(value, slice) and value == slice(None)

Expand Down
Loading
Loading