Skip to content

Commit c204229

Browse files
committed
REF: Reuse async loading logic for DataTree and open_groups (Fixes #11131)
1 parent 8cf3ad7 commit c204229

File tree

1 file changed

+63
-58
lines changed

1 file changed

+63
-58
lines changed

xarray/backends/zarr.py

Lines changed: 63 additions & 58 deletions
Original file line number · Diff line number · Diff line change
@@ -1891,6 +1891,7 @@ def open_datatree(
18911891
else:
18921892
parent = str(NodePath("/"))
18931893

1894+
# Open stores synchronously to avoid nested event-loop issues
18941895
stores = ZarrStore.open_store(
18951896
filename_or_obj,
18961897
group=parent,
@@ -1908,8 +1909,8 @@ def open_datatree(
19081909
if _zarr_v3():
19091910
from zarr.core.sync import sync as zarr_sync
19101911

1911-
return zarr_sync(
1912-
self._open_datatree_from_stores_async(
1912+
groups_dict = zarr_sync(
1913+
self._open_groups_from_stores_async(
19131914
stores=stores,
19141915
parent=parent,
19151916
group=group,
@@ -1924,29 +1925,28 @@ def open_datatree(
19241925
)
19251926
)
19261927
else:
1927-
# Fallback for zarr v2: sequential loading
1928-
groups_dict = {}
1929-
for path_group, store in stores.items():
1930-
store_entrypoint = StoreBackendEntrypoint()
1931-
with close_on_error(store):
1932-
group_ds = store_entrypoint.open_dataset(
1933-
store,
1934-
mask_and_scale=mask_and_scale,
1935-
decode_times=decode_times,
1936-
concat_characters=concat_characters,
1937-
decode_coords=decode_coords,
1938-
drop_variables=drop_variables,
1939-
use_cftime=use_cftime,
1940-
decode_timedelta=decode_timedelta,
1941-
)
1942-
if group:
1943-
group_name = str(NodePath(path_group).relative_to(parent))
1944-
else:
1945-
group_name = str(NodePath(path_group))
1946-
groups_dict[group_name] = group_ds
1947-
return datatree_from_dict_with_io_cleanup(groups_dict)
1928+
groups_dict = self.open_groups_as_dict(
1929+
filename_or_obj,
1930+
mask_and_scale=mask_and_scale,
1931+
decode_times=decode_times,
1932+
concat_characters=concat_characters,
1933+
decode_coords=decode_coords,
1934+
drop_variables=drop_variables,
1935+
use_cftime=use_cftime,
1936+
decode_timedelta=decode_timedelta,
1937+
group=group,
1938+
mode=mode,
1939+
synchronizer=synchronizer,
1940+
consolidated=consolidated,
1941+
chunk_store=chunk_store,
1942+
storage_options=storage_options,
1943+
zarr_version=zarr_version,
1944+
zarr_format=zarr_format,
1945+
)
19481946

1949-
async def _open_datatree_from_stores_async(
1947+
return datatree_from_dict_with_io_cleanup(groups_dict)
1948+
1949+
async def _open_groups_from_stores_async(
19501950
self,
19511951
stores: dict,
19521952
parent: str,
@@ -1960,11 +1960,22 @@ async def _open_datatree_from_stores_async(
19601960
use_cftime=None,
19611961
decode_timedelta=None,
19621962
max_concurrency: int | None = None,
1963-
) -> DataTree:
1964-
"""Async helper to open datasets from pre-opened stores and create indexes.
1963+
) -> dict[str, Dataset]:
1964+
"""Shared async core: open datasets from pre-opened stores concurrently.
19651965
1966-
This method takes already-opened stores (avoiding nested zarr_sync() calls)
1967-
and runs the Dataset opening and index creation concurrently.
1966+
This takes already-opened stores (avoiding nested zarr_sync() calls)
1967+
and runs Dataset opening and index creation with bounded concurrency.
1968+
1969+
Parameters
1970+
----------
1971+
stores : dict
1972+
Mapping of group paths to already-opened ZarrStore instances.
1973+
parent : str
1974+
The resolved parent group path.
1975+
group : str or None
1976+
The user-requested group, used for relative path computation.
1977+
max_concurrency : int or None, optional
1978+
Maximum number of groups to open concurrently. Defaults to 10.
19681979
"""
19691980
from xarray.backends.api import _maybe_create_default_indexes_async
19701981

@@ -1986,7 +1997,7 @@ async def open_one(path_group: str, store) -> tuple[str, Dataset]:
19861997
use_cftime=use_cftime,
19871998
decode_timedelta=decode_timedelta,
19881999
)
1989-
# Create indexes in parallel (within this group)
2000+
# Create indexes concurrently
19902001
ds = await _maybe_create_default_indexes_async(ds)
19912002
if group:
19922003
group_name = str(NodePath(path_group).relative_to(parent))
@@ -2004,7 +2015,7 @@ async def collect_result(path_group: str, store) -> None:
20042015
for path_group, store in stores.items():
20052016
tg.create_task(collect_result(path_group, store))
20062017

2007-
return datatree_from_dict_with_io_cleanup(groups_dict)
2018+
return groups_dict
20082019

20092020
def open_groups_as_dict(
20102021
self,
@@ -2088,11 +2099,18 @@ async def open_groups_as_dict_async(
20882099
storage_options=None,
20892100
zarr_version=None,
20902101
zarr_format=None,
2102+
max_concurrency: int | None = None,
20912103
) -> dict[str, Dataset]:
20922104
"""Asynchronously open each group into a Dataset concurrently.
20932105
2094-
This mirrors open_groups_as_dict but parallelizes per-group Dataset opening,
2106+
This mirrors open_groups_as_dict but parallelizes per-group Dataset
2107+
opening with bounded concurrency and proper async index creation,
20952108
which can significantly reduce latency on high-RTT object stores.
2109+
2110+
Parameters
2111+
----------
2112+
max_concurrency : int or None, optional
2113+
Maximum number of groups to open concurrently. Defaults to 10.
20962114
"""
20972115
filename_or_obj = _normalize_path(filename_or_obj)
20982116

@@ -2115,40 +2133,27 @@ async def open_groups_as_dict_async(
21152133
zarr_format=zarr_format,
21162134
)
21172135

2118-
async def open_one(path_group: str, store) -> tuple[str, Dataset]:
2119-
store_entrypoint = StoreBackendEntrypoint()
2120-
2121-
def _load_sync():
2122-
with close_on_error(store):
2123-
return store_entrypoint.open_dataset(
2124-
store,
2125-
mask_and_scale=mask_and_scale,
2126-
decode_times=decode_times,
2127-
concat_characters=concat_characters,
2128-
decode_coords=decode_coords,
2129-
drop_variables=drop_variables,
2130-
use_cftime=use_cftime,
2131-
decode_timedelta=decode_timedelta,
2132-
)
2133-
2134-
ds = await asyncio.to_thread(_load_sync)
2135-
if group:
2136-
group_name = str(NodePath(path_group).relative_to(parent))
2137-
else:
2138-
group_name = str(NodePath(path_group))
2139-
return group_name, ds
2140-
2141-
tasks = [open_one(path_group, store) for path_group, store in stores.items()]
2142-
results = await asyncio.gather(*tasks)
2143-
return dict(results)
2136+
return await self._open_groups_from_stores_async(
2137+
stores=stores,
2138+
parent=parent,
2139+
group=group,
2140+
mask_and_scale=mask_and_scale,
2141+
decode_times=decode_times,
2142+
concat_characters=concat_characters,
2143+
decode_coords=decode_coords,
2144+
drop_variables=drop_variables,
2145+
use_cftime=use_cftime,
2146+
decode_timedelta=decode_timedelta,
2147+
max_concurrency=max_concurrency,
2148+
)
21442149

21452150

21462151
def _build_group_members(
21472152
zarr_group: ZarrGroup,
21482153
group_paths: list[str],
21492154
parent: str | None,
21502155
) -> dict[str, ZarrGroup]:
2151-
parent = parent if parent else "/"
2156+
parent = parent or "/"
21522157
group_members: dict[str, ZarrGroup] = {}
21532158

21542159
for path in group_paths:

0 commit comments

Comments (0)