Skip to content

Commit ab0cbfa

Browse files
committed
Address PR review: TaskGroup, max_concurrency, and open_dataset_async
- Replace asyncio.gather with asyncio.TaskGroup for better error handling (cancels outstanding tasks on error) - Add max_concurrency parameter to open_datatree for controlling parallel I/O operations (defaults to 10) - Add StoreBackendEntrypoint.open_dataset_async method - Add test for open_dataset_async equivalence
1 parent 87fd361 commit ab0cbfa

5 files changed

Lines changed: 99 additions & 22 deletions

File tree

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ v2026.02.0 (unreleased)
1414
New Features
1515
~~~~~~~~~~~~
1616

17+
- Added ``max_concurrency`` parameter to :py:func:`open_datatree` to control
18+
the maximum number of concurrent I/O operations when opening groups in parallel
19+
with the Zarr backend (:pull:`10742`).
20+
By `Alfonso Ladino <https://github.com/aladinor>`_.
1721

1822
Breaking Changes
1923
~~~~~~~~~~~~~~~~

xarray/backends/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,7 @@ def open_datatree(
942942
chunked_array_type: str | None = None,
943943
from_array_kwargs: dict[str, Any] | None = None,
944944
backend_kwargs: dict[str, Any] | None = None,
945+
max_concurrency: int | None = None,
945946
**kwargs,
946947
) -> DataTree:
947948
"""
@@ -1074,6 +1075,13 @@ def open_datatree(
10741075
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
10751076
For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed
10761077
to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
1078+
max_concurrency : int, optional
1079+
Maximum number of concurrent I/O operations when opening groups in
1080+
parallel. This limits the number of groups that are loaded simultaneously.
1081+
Useful for controlling resource usage with large datatrees or stores
1082+
that may have limitations on concurrent access (e.g., icechunk).
1083+
Only used by backends that support parallel loading (currently Zarr v3).
1084+
If None (default), the backend uses its default value (typically 10).
10771085
backend_kwargs: dict
10781086
Additional keyword arguments passed on to the engine open function,
10791087
equivalent to `**kwargs`.
@@ -1134,6 +1142,9 @@ def open_datatree(
11341142
)
11351143
overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
11361144

1145+
if max_concurrency is not None:
1146+
kwargs["max_concurrency"] = max_concurrency
1147+
11371148
backend_tree = backend.open_datatree(
11381149
filename_or_obj,
11391150
drop_variables=drop_variables,

xarray/backends/store.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from collections.abc import Iterable
45
from typing import TYPE_CHECKING
56

@@ -72,5 +73,37 @@ def open_dataset(
7273

7374
return ds
7475

76+
async def open_dataset_async(
77+
self,
78+
filename_or_obj: T_PathFileOrDataStore,
79+
*,
80+
mask_and_scale=True,
81+
decode_times=True,
82+
concat_characters=True,
83+
decode_coords=True,
84+
drop_variables: str | Iterable[str] | None = None,
85+
set_indexes: bool = True,
86+
use_cftime=None,
87+
decode_timedelta=None,
88+
) -> Dataset:
89+
"""Async version of open_dataset.
90+
91+
Offloads the entire open_dataset operation to a thread to avoid blocking
92+
the event loop. This is necessary because decode_cf_variables can trigger
93+
data reads (e.g., for time decoding) which may use synchronous I/O.
94+
"""
95+
return await asyncio.to_thread(
96+
self.open_dataset,
97+
filename_or_obj,
98+
mask_and_scale=mask_and_scale,
99+
decode_times=decode_times,
100+
concat_characters=concat_characters,
101+
decode_coords=decode_coords,
102+
drop_variables=drop_variables,
103+
set_indexes=set_indexes,
104+
use_cftime=use_cftime,
105+
decode_timedelta=decode_timedelta,
106+
)
107+
75108

76109
BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint)

xarray/backends/zarr.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,7 @@ def open_datatree(
18791879
storage_options=None,
18801880
zarr_version=None,
18811881
zarr_format=None,
1882+
max_concurrency: int | None = None,
18821883
) -> DataTree:
18831884
filename_or_obj = _normalize_path(filename_or_obj)
18841885

@@ -1916,6 +1917,7 @@ def open_datatree(
19161917
drop_variables=drop_variables,
19171918
use_cftime=use_cftime,
19181919
decode_timedelta=decode_timedelta,
1920+
max_concurrency=max_concurrency,
19191921
)
19201922
)
19211923
else:
@@ -1954,6 +1956,7 @@ async def _open_datatree_from_stores_async(
19541956
drop_variables: str | Iterable[str] | None = None,
19551957
use_cftime=None,
19561958
decode_timedelta=None,
1959+
max_concurrency: int | None = None,
19571960
) -> DataTree:
19581961
"""Async helper to open datasets from pre-opened stores and create indexes.
19591962
@@ -1962,28 +1965,24 @@ async def _open_datatree_from_stores_async(
19621965
"""
19631966
from xarray.backends.api import _maybe_create_default_indexes_async
19641967

1965-
# Limit concurrent to_thread calls to avoid deadlocks with some stores
1966-
# (e.g., icechunk can deadlock when too many threads access it simultaneously)
1967-
sem = asyncio.Semaphore(10)
1968+
if max_concurrency is None:
1969+
max_concurrency = 10
1970+
sem = asyncio.Semaphore(max_concurrency)
19681971

19691972
async def open_one(path_group: str, store) -> tuple[str, Dataset]:
19701973
async with sem:
19711974
store_entrypoint = StoreBackendEntrypoint()
1972-
1973-
def _load_sync():
1974-
with close_on_error(store):
1975-
return store_entrypoint.open_dataset(
1976-
store,
1977-
mask_and_scale=mask_and_scale,
1978-
decode_times=decode_times,
1979-
concat_characters=concat_characters,
1980-
decode_coords=decode_coords,
1981-
drop_variables=drop_variables,
1982-
use_cftime=use_cftime,
1983-
decode_timedelta=decode_timedelta,
1984-
)
1985-
1986-
ds = await asyncio.to_thread(_load_sync)
1975+
with close_on_error(store):
1976+
ds = await store_entrypoint.open_dataset_async(
1977+
store,
1978+
mask_and_scale=mask_and_scale,
1979+
decode_times=decode_times,
1980+
concat_characters=concat_characters,
1981+
decode_coords=decode_coords,
1982+
drop_variables=drop_variables,
1983+
use_cftime=use_cftime,
1984+
decode_timedelta=decode_timedelta,
1985+
)
19871986
# Create indexes in parallel (within this group)
19881987
ds = await _maybe_create_default_indexes_async(ds)
19891988
if group:
@@ -1992,10 +1991,15 @@ def _load_sync():
19921991
group_name = str(NodePath(path_group))
19931992
return group_name, ds
19941993

1995-
# Open all datasets and create indexes concurrently
1996-
tasks = [open_one(path_group, store) for path_group, store in stores.items()]
1997-
results = await asyncio.gather(*tasks)
1998-
groups_dict = dict(results)
1994+
groups_dict: dict[str, Dataset] = {}
1995+
1996+
async def collect_result(path_group: str, store) -> None:
1997+
group_name, ds = await open_one(path_group, store)
1998+
groups_dict[group_name] = ds
1999+
2000+
async with asyncio.TaskGroup() as tg:
2001+
for path_group, store in stores.items():
2002+
tg.create_task(collect_result(path_group, store))
19992003

20002004
return datatree_from_dict_with_io_cleanup(groups_dict)
20012005

xarray/tests/test_backends_zarr_async.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import xarray as xr
1111
from xarray.backends.api import _maybe_create_default_indexes_async
12+
from xarray.backends.store import StoreBackendEntrypoint
1213
from xarray.backends.zarr import ZarrBackendEntrypoint
1314
from xarray.testing import assert_equal
1415
from xarray.tests import (
@@ -232,3 +233,27 @@ def test_sync_open_datatree_uses_async_internally(self, zarr_format):
232233
# For zarr v3, the async function should be called
233234
assert mock_async.call_count > 0
234235
assert_equal(dtree, dtree_loaded)
236+
237+
@pytest.mark.asyncio
238+
@requires_zarr_v3
239+
@parametrize_zarr_format
240+
async def test_store_backend_open_dataset_async_equivalence(self, zarr_format):
241+
"""Test that StoreBackendEntrypoint.open_dataset_async returns same result as sync."""
242+
from xarray.backends.zarr import ZarrStore
243+
244+
ds = create_dataset_with_coordinates(n_coords=3)
245+
246+
with self.create_zarr_store() as store:
247+
ds.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format)
248+
249+
zarr_store = ZarrStore.open_group(
250+
store,
251+
consolidated=False,
252+
zarr_format=zarr_format,
253+
)
254+
255+
store_entrypoint = StoreBackendEntrypoint()
256+
ds_sync = store_entrypoint.open_dataset(zarr_store)
257+
ds_async = await store_entrypoint.open_dataset_async(zarr_store)
258+
259+
assert_equal(ds_sync, ds_async)

0 commit comments

Comments
 (0)