Skip to content

Commit 89c913a

Browse files
authored
Support compute=False from DataTree.to_netcdf (#10625)
* Refactor to_netcdf() and to_zarr() internals * Fixes per review * Clean up comments * Fix type for to_netcdf() * Add test and whats-new for cross-group redundant computation * Fix test failure on CI (and add a better test) * grammar
1 parent 5addf47 commit 89c913a

11 files changed

Lines changed: 444 additions & 133 deletions

File tree

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ v2025.08.1 (unreleased)
1313
New Features
1414
~~~~~~~~~~~~
1515

16+
- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and
17+
:py:meth:`DataTree.to_zarr`.
18+
By `Stephan Hoyer <https://github.com/shoyer>`_.
1619

1720
Breaking changes
1821
~~~~~~~~~~~~~~~~
@@ -29,6 +32,10 @@ Bug fixes
2932
- Warn instead of raise in case of misconfiguration of ``unlimited_dims`` originating from dataset.encoding, to prevent breaking users workflows (:issue:`10647`, :pull:`10648`).
3033
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
3134

35+
- :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr` now avoid
36+
redundant computation of Dask arrays with cross-group dependencies
37+
(:issue:`10637`).
38+
By `Stephan Hoyer <https://github.com/shoyer>`_.
3239

3340
Documentation
3441
~~~~~~~~~~~~~

xarray/backends/api.py

Lines changed: 136 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@
3131
from xarray.backends import plugins
3232
from xarray.backends.common import (
3333
AbstractDataStore,
34+
AbstractWritableDataStore,
3435
ArrayWriter,
3536
BytesIOProxy,
3637
T_PathFileOrDataStore,
3738
_find_absolute_paths,
3839
_normalize_path,
3940
)
40-
from xarray.backends.locks import _get_scheduler
41+
from xarray.backends.locks import get_dask_scheduler
4142
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
4243
from xarray.core import dtypes, indexing
4344
from xarray.core.coordinates import Coordinates
@@ -310,12 +311,18 @@ def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None:
310311
_protect_dataset_variables_inplace(node.dataset, cache)
311312

312313

313-
def _finalize_store(write, store):
314+
def _finalize_store(writes, store):
314315
"""Finalize this store by explicitly syncing and closing"""
315-
del write # ensure writing is done first
316+
del writes # ensure writing is done first
316317
store.close()
317318

318319

320+
def delayed_close_after_writes(writes, store):
321+
import dask
322+
323+
return dask.delayed(_finalize_store)(writes, store)
324+
325+
319326
def _multi_file_closer(closers):
320327
for closer in closers:
321328
closer()
@@ -1858,6 +1865,39 @@ def open_mfdataset(
18581865
}
18591866

18601867

1868+
def get_writable_netcdf_store(
1869+
target,
1870+
engine: T_NetcdfEngine,
1871+
*,
1872+
format: T_NetcdfTypes | None,
1873+
mode: NetcdfWriteModes,
1874+
autoclose: bool,
1875+
invalid_netcdf: bool,
1876+
auto_complex: bool | None,
1877+
) -> AbstractWritableDataStore:
1878+
"""Create a store for writing to a netCDF file."""
1879+
try:
1880+
store_open = WRITEABLE_STORES[engine]
1881+
except KeyError as err:
1882+
raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err
1883+
1884+
if format is not None:
1885+
format = format.upper() # type: ignore[assignment]
1886+
1887+
kwargs = dict(autoclose=True) if autoclose else {}
1888+
if invalid_netcdf:
1889+
if engine == "h5netcdf":
1890+
kwargs["invalid_netcdf"] = invalid_netcdf
1891+
else:
1892+
raise ValueError(
1893+
f"unrecognized option 'invalid_netcdf' for engine {engine}"
1894+
)
1895+
if auto_complex is not None:
1896+
kwargs["auto_complex"] = auto_complex
1897+
1898+
return store_open(target, mode=mode, format=format, **kwargs)
1899+
1900+
18611901
# multifile=True returns writer and datastore
18621902
@overload
18631903
def to_netcdf(
@@ -2043,16 +2083,8 @@ def to_netcdf(
20432083
# sanitize unlimited_dims
20442084
unlimited_dims = _sanitize_unlimited_dims(dataset, unlimited_dims)
20452085

2046-
try:
2047-
store_open = WRITEABLE_STORES[engine]
2048-
except KeyError as err:
2049-
raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err
2050-
2051-
if format is not None:
2052-
format = format.upper() # type: ignore[assignment]
2053-
20542086
# handle scheduler specific logic
2055-
scheduler = _get_scheduler()
2087+
scheduler = get_dask_scheduler()
20562088
have_chunks = any(v.chunks is not None for v in dataset.variables.values())
20572089

20582090
autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
@@ -2067,18 +2099,17 @@ def to_netcdf(
20672099
else:
20682100
target = path_or_file # type: ignore[assignment]
20692101

2070-
kwargs = dict(autoclose=True) if autoclose else {}
2071-
if invalid_netcdf:
2072-
if engine == "h5netcdf":
2073-
kwargs["invalid_netcdf"] = invalid_netcdf
2074-
else:
2075-
raise ValueError(
2076-
f"unrecognized option 'invalid_netcdf' for engine {engine}"
2077-
)
2078-
if auto_complex is not None:
2079-
kwargs["auto_complex"] = auto_complex
2080-
2081-
store = store_open(target, mode, format, group, **kwargs)
2102+
store = get_writable_netcdf_store(
2103+
target,
2104+
engine,
2105+
mode=mode,
2106+
format=format,
2107+
autoclose=autoclose,
2108+
invalid_netcdf=invalid_netcdf,
2109+
auto_complex=auto_complex,
2110+
)
2111+
if group is not None:
2112+
store = store.get_child_store(group)
20822113

20832114
writer = ArrayWriter()
20842115

@@ -2099,17 +2130,18 @@ def to_netcdf(
20992130
writes = writer.sync(compute=compute)
21002131

21012132
finally:
2102-
if not multifile and compute: # type: ignore[redundant-expr]
2103-
store.close()
2133+
if not multifile:
2134+
if compute:
2135+
store.close()
2136+
else:
2137+
store.sync()
21042138

21052139
if path_or_file is None:
21062140
assert isinstance(target, BytesIOProxy) # created in this function
21072141
return target.getvalue_or_getbuffer()
21082142

21092143
if not compute:
2110-
import dask
2111-
2112-
return dask.delayed(_finalize_store)(writes, store)
2144+
return delayed_close_after_writes(writes, store)
21132145

21142146
return None
21152147

@@ -2265,20 +2297,71 @@ def save_mfdataset(
22652297
try:
22662298
writes = [w.sync(compute=compute) for w in writers]
22672299
finally:
2268-
if compute:
2269-
for store in stores:
2300+
for store in stores:
2301+
if compute:
22702302
store.close()
2303+
else:
2304+
store.sync()
22712305

22722306
if not compute:
22732307
import dask
22742308

22752309
return dask.delayed(
2276-
list(
2277-
starmap(dask.delayed(_finalize_store), zip(writes, stores, strict=True))
2278-
)
2310+
list(starmap(delayed_close_after_writes, zip(writes, stores, strict=True)))
22792311
)
22802312

22812313

2314+
def get_writable_zarr_store(
2315+
store: ZarrStoreLike | None = None,
2316+
*,
2317+
chunk_store: MutableMapping | str | os.PathLike | None = None,
2318+
mode: ZarrWriteModes | None = None,
2319+
synchronizer=None,
2320+
group: str | None = None,
2321+
consolidated: bool | None = None,
2322+
append_dim: Hashable | None = None,
2323+
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
2324+
safe_chunks: bool = True,
2325+
align_chunks: bool = False,
2326+
storage_options: dict[str, str] | None = None,
2327+
zarr_version: int | None = None,
2328+
zarr_format: int | None = None,
2329+
write_empty_chunks: bool | None = None,
2330+
) -> backends.ZarrStore:
2331+
"""Create a store for writing to Zarr."""
2332+
from xarray.backends.zarr import _choose_default_mode, _get_mappers
2333+
2334+
kwargs, mapper, chunk_mapper = _get_mappers(
2335+
storage_options=storage_options, store=store, chunk_store=chunk_store
2336+
)
2337+
mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)
2338+
2339+
if mode == "r+":
2340+
already_consolidated = consolidated
2341+
consolidate_on_close = False
2342+
else:
2343+
already_consolidated = False
2344+
consolidate_on_close = consolidated or consolidated is None
2345+
2346+
return backends.ZarrStore.open_group(
2347+
store=mapper,
2348+
mode=mode,
2349+
synchronizer=synchronizer,
2350+
group=group,
2351+
consolidated=already_consolidated,
2352+
consolidate_on_close=consolidate_on_close,
2353+
chunk_store=chunk_mapper,
2354+
append_dim=append_dim,
2355+
write_region=region,
2356+
safe_chunks=safe_chunks,
2357+
align_chunks=align_chunks,
2358+
zarr_version=zarr_version,
2359+
zarr_format=zarr_format,
2360+
write_empty=write_empty_chunks,
2361+
**kwargs,
2362+
)
2363+
2364+
22822365
# compute=True returns ZarrStore
22832366
@overload
22842367
def to_zarr(
@@ -2353,7 +2436,6 @@ def to_zarr(
23532436
23542437
See `Dataset.to_zarr` for full API docs.
23552438
"""
2356-
from xarray.backends.zarr import _choose_default_mode, _get_mappers
23572439

23582440
# validate Dataset keys, DataArray names
23592441
_validate_dataset_names(dataset)
@@ -2368,53 +2450,39 @@ def to_zarr(
23682450
if encoding is None:
23692451
encoding = {}
23702452

2371-
kwargs, mapper, chunk_mapper = _get_mappers(
2372-
storage_options=storage_options, store=store, chunk_store=chunk_store
2373-
)
2374-
mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)
2375-
2376-
if mode == "r+":
2377-
already_consolidated = consolidated
2378-
consolidate_on_close = False
2379-
else:
2380-
already_consolidated = False
2381-
consolidate_on_close = consolidated or consolidated is None
2382-
2383-
zstore = backends.ZarrStore.open_group(
2384-
store=mapper,
2453+
zstore = get_writable_zarr_store(
2454+
store,
2455+
chunk_store=chunk_store,
23852456
mode=mode,
23862457
synchronizer=synchronizer,
23872458
group=group,
2388-
consolidated=already_consolidated,
2389-
consolidate_on_close=consolidate_on_close,
2390-
chunk_store=chunk_mapper,
2459+
consolidated=consolidated,
23912460
append_dim=append_dim,
2392-
write_region=region,
2461+
region=region,
23932462
safe_chunks=safe_chunks,
23942463
align_chunks=align_chunks,
2464+
storage_options=storage_options,
23952465
zarr_version=zarr_version,
23962466
zarr_format=zarr_format,
2397-
write_empty=write_empty_chunks,
2398-
**kwargs,
2467+
write_empty_chunks=write_empty_chunks,
23992468
)
24002469

2401-
dataset = zstore._validate_and_autodetect_region(
2402-
dataset,
2403-
)
2470+
dataset = zstore._validate_and_autodetect_region(dataset)
24042471
zstore._validate_encoding(encoding)
24052472

24062473
writer = ArrayWriter()
2407-
# TODO: figure out how to properly handle unlimited_dims
2408-
dump_to_store(dataset, zstore, writer, encoding=encoding)
2409-
writes = writer.sync(
2410-
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
2411-
)
24122474

2413-
if compute:
2414-
_finalize_store(writes, zstore)
2415-
else:
2416-
import dask
2475+
# TODO: figure out how to properly handle unlimited_dims
2476+
try:
2477+
dump_to_store(dataset, zstore, writer, encoding=encoding)
2478+
writes = writer.sync(
2479+
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
2480+
)
2481+
finally:
2482+
if compute:
2483+
zstore.close()
24172484

2418-
return dask.delayed(_finalize_store)(writes, zstore)
2485+
if not compute:
2486+
return delayed_close_after_writes(writes, zstore)
24192487

24202488
return zstore

xarray/backends/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Any,
1313
ClassVar,
1414
Generic,
15+
Self,
1516
TypeVar,
1617
Union,
1718
overload,
@@ -326,6 +327,10 @@ async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None):
326327
class AbstractDataStore:
327328
__slots__ = ()
328329

330+
def get_child_store(self, group: str) -> Self: # pragma: no cover
331+
"""Get a store corresponding to the indicated child group."""
332+
raise NotImplementedError()
333+
329334
def get_dimensions(self): # pragma: no cover
330335
raise NotImplementedError()
331336

@@ -606,6 +611,10 @@ def set_dimensions(self, variables, unlimited_dims=None):
606611
is_unlimited = dim in unlimited_dims
607612
self.set_dimension(dim, length, is_unlimited)
608613

614+
def sync(self):
615+
"""Write all buffered data to disk."""
616+
raise NotImplementedError()
617+
609618

610619
def _infer_dtype(array, name=None):
611620
"""Given an object array with no missing values, infer its dtype from all elements."""

xarray/backends/h5netcdf_.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import io
55
import os
66
from collections.abc import Iterable
7-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, Self
88

99
import numpy as np
1010

@@ -150,6 +150,17 @@ def __init__(
150150
self.lock = ensure_lock(lock)
151151
self.autoclose = autoclose
152152

153+
def get_child_store(self, group: str) -> Self:
154+
if self._group is not None:
155+
group = os.path.join(self._group, group)
156+
return type(self)(
157+
self._manager,
158+
group=group,
159+
mode=self._mode,
160+
lock=self.lock,
161+
autoclose=self.autoclose,
162+
)
163+
153164
@classmethod
154165
def open(
155166
cls,

0 commit comments

Comments
 (0)