@@ -1879,6 +1879,7 @@ def open_datatree(
18791879 storage_options = None ,
18801880 zarr_version = None ,
18811881 zarr_format = None ,
1882+ max_concurrency : int | None = None ,
18821883 ) -> DataTree :
18831884 filename_or_obj = _normalize_path (filename_or_obj )
18841885
@@ -1916,6 +1917,7 @@ def open_datatree(
19161917 drop_variables = drop_variables ,
19171918 use_cftime = use_cftime ,
19181919 decode_timedelta = decode_timedelta ,
1920+ max_concurrency = max_concurrency ,
19191921 )
19201922 )
19211923 else :
@@ -1954,6 +1956,7 @@ async def _open_datatree_from_stores_async(
19541956 drop_variables : str | Iterable [str ] | None = None ,
19551957 use_cftime = None ,
19561958 decode_timedelta = None ,
1959+ max_concurrency : int | None = None ,
19571960 ) -> DataTree :
19581961 """Async helper to open datasets from pre-opened stores and create indexes.
19591962
@@ -1962,28 +1965,24 @@ async def _open_datatree_from_stores_async(
19621965 """
19631966 from xarray .backends .api import _maybe_create_default_indexes_async
19641967
1965- # Limit concurrent to_thread calls to avoid deadlocks with some stores
1966- # (e.g., icechunk can deadlock when too many threads access it simultaneously)
1967- sem = asyncio .Semaphore (10 )
1968+ if max_concurrency is None :
1969+ max_concurrency = 10
1970+ sem = asyncio .Semaphore (max_concurrency )
19681971
19691972 async def open_one (path_group : str , store ) -> tuple [str , Dataset ]:
19701973 async with sem :
19711974 store_entrypoint = StoreBackendEntrypoint ()
1972-
1973- def _load_sync ():
1974- with close_on_error (store ):
1975- return store_entrypoint .open_dataset (
1976- store ,
1977- mask_and_scale = mask_and_scale ,
1978- decode_times = decode_times ,
1979- concat_characters = concat_characters ,
1980- decode_coords = decode_coords ,
1981- drop_variables = drop_variables ,
1982- use_cftime = use_cftime ,
1983- decode_timedelta = decode_timedelta ,
1984- )
1985-
1986- ds = await asyncio .to_thread (_load_sync )
1975+ with close_on_error (store ):
1976+ ds = await store_entrypoint .open_dataset_async (
1977+ store ,
1978+ mask_and_scale = mask_and_scale ,
1979+ decode_times = decode_times ,
1980+ concat_characters = concat_characters ,
1981+ decode_coords = decode_coords ,
1982+ drop_variables = drop_variables ,
1983+ use_cftime = use_cftime ,
1984+ decode_timedelta = decode_timedelta ,
1985+ )
19871986 # Create indexes in parallel (within this group)
19881987 ds = await _maybe_create_default_indexes_async (ds )
19891988 if group :
@@ -1992,10 +1991,15 @@ def _load_sync():
19921991 group_name = str (NodePath (path_group ))
19931992 return group_name , ds
19941993
1995- # Open all datasets and create indexes concurrently
1996- tasks = [open_one (path_group , store ) for path_group , store in stores .items ()]
1997- results = await asyncio .gather (* tasks )
1998- groups_dict = dict (results )
1994+ groups_dict : dict [str , Dataset ] = {}
1995+
1996+ async def collect_result (path_group : str , store ) -> None :
1997+ group_name , ds = await open_one (path_group , store )
1998+ groups_dict [group_name ] = ds
1999+
2000+ async with asyncio .TaskGroup () as tg :
2001+ for path_group , store in stores .items ():
2002+ tg .create_task (collect_result (path_group , store ))
19992003
20002004 return datatree_from_dict_with_io_cleanup (groups_dict )
20012005
0 commit comments