3131from xarray .backends import plugins
3232from xarray .backends .common import (
3333 AbstractDataStore ,
34+ AbstractWritableDataStore ,
3435 ArrayWriter ,
3536 BytesIOProxy ,
3637 T_PathFileOrDataStore ,
3738 _find_absolute_paths ,
3839 _normalize_path ,
3940)
40- from xarray .backends .locks import _get_scheduler
41+ from xarray .backends .locks import get_dask_scheduler
4142from xarray .coders import CFDatetimeCoder , CFTimedeltaCoder
4243from xarray .core import dtypes , indexing
4344from xarray .core .coordinates import Coordinates
@@ -310,12 +311,18 @@ def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None:
310311 _protect_dataset_variables_inplace (node .dataset , cache )
311312
312313
313- def _finalize_store (write , store ):
314+ def _finalize_store (writes , store ):
314315 """Finalize this store by explicitly syncing and closing"""
315- del write # ensure writing is done first
316+ del writes # ensure writing is done first
316317 store .close ()
317318
318319
320+ def delayed_close_after_writes (writes , store ):
321+ import dask
322+
323+ return dask .delayed (_finalize_store )(writes , store )
324+
325+
319326def _multi_file_closer (closers ):
320327 for closer in closers :
321328 closer ()
@@ -1858,6 +1865,39 @@ def open_mfdataset(
18581865}
18591866
18601867
1868+ def get_writable_netcdf_store (
1869+ target ,
1870+ engine : T_NetcdfEngine ,
1871+ * ,
1872+ format : T_NetcdfTypes | None ,
1873+ mode : NetcdfWriteModes ,
1874+ autoclose : bool ,
1875+ invalid_netcdf : bool ,
1876+ auto_complex : bool | None ,
1877+ ) -> AbstractWritableDataStore :
1878+ """Create a store for writing to a netCDF file."""
1879+ try :
1880+ store_open = WRITEABLE_STORES [engine ]
1881+ except KeyError as err :
1882+ raise ValueError (f"unrecognized engine for to_netcdf: { engine !r} " ) from err
1883+
1884+ if format is not None :
1885+ format = format .upper () # type: ignore[assignment]
1886+
1887+ kwargs = dict (autoclose = True ) if autoclose else {}
1888+ if invalid_netcdf :
1889+ if engine == "h5netcdf" :
1890+ kwargs ["invalid_netcdf" ] = invalid_netcdf
1891+ else :
1892+ raise ValueError (
1893+ f"unrecognized option 'invalid_netcdf' for engine { engine } "
1894+ )
1895+ if auto_complex is not None :
1896+ kwargs ["auto_complex" ] = auto_complex
1897+
1898+ return store_open (target , mode = mode , format = format , ** kwargs )
1899+
1900+
18611901# multifile=True returns writer and datastore
18621902@overload
18631903def to_netcdf (
@@ -2043,16 +2083,8 @@ def to_netcdf(
20432083 # sanitize unlimited_dims
20442084 unlimited_dims = _sanitize_unlimited_dims (dataset , unlimited_dims )
20452085
2046- try :
2047- store_open = WRITEABLE_STORES [engine ]
2048- except KeyError as err :
2049- raise ValueError (f"unrecognized engine for to_netcdf: { engine !r} " ) from err
2050-
2051- if format is not None :
2052- format = format .upper () # type: ignore[assignment]
2053-
20542086 # handle scheduler specific logic
2055- scheduler = _get_scheduler ()
2087+ scheduler = get_dask_scheduler ()
20562088 have_chunks = any (v .chunks is not None for v in dataset .variables .values ())
20572089
20582090 autoclose = have_chunks and scheduler in ["distributed" , "multiprocessing" ]
@@ -2067,18 +2099,17 @@ def to_netcdf(
20672099 else :
20682100 target = path_or_file # type: ignore[assignment]
20692101
2070- kwargs = dict (autoclose = True ) if autoclose else {}
2071- if invalid_netcdf :
2072- if engine == "h5netcdf" :
2073- kwargs ["invalid_netcdf" ] = invalid_netcdf
2074- else :
2075- raise ValueError (
2076- f"unrecognized option 'invalid_netcdf' for engine { engine } "
2077- )
2078- if auto_complex is not None :
2079- kwargs ["auto_complex" ] = auto_complex
2080-
2081- store = store_open (target , mode , format , group , ** kwargs )
2102+ store = get_writable_netcdf_store (
2103+ target ,
2104+ engine ,
2105+ mode = mode ,
2106+ format = format ,
2107+ autoclose = autoclose ,
2108+ invalid_netcdf = invalid_netcdf ,
2109+ auto_complex = auto_complex ,
2110+ )
2111+ if group is not None :
2112+ store = store .get_child_store (group )
20822113
20832114 writer = ArrayWriter ()
20842115
@@ -2099,17 +2130,18 @@ def to_netcdf(
20992130 writes = writer .sync (compute = compute )
21002131
21012132 finally :
2102- if not multifile and compute : # type: ignore[redundant-expr]
2103- store .close ()
2133+ if not multifile :
2134+ if compute :
2135+ store .close ()
2136+ else :
2137+ store .sync ()
21042138
21052139 if path_or_file is None :
21062140 assert isinstance (target , BytesIOProxy ) # created in this function
21072141 return target .getvalue_or_getbuffer ()
21082142
21092143 if not compute :
2110- import dask
2111-
2112- return dask .delayed (_finalize_store )(writes , store )
2144+ return delayed_close_after_writes (writes , store )
21132145
21142146 return None
21152147
@@ -2265,20 +2297,71 @@ def save_mfdataset(
22652297 try :
22662298 writes = [w .sync (compute = compute ) for w in writers ]
22672299 finally :
2268- if compute :
2269- for store in stores :
2300+ for store in stores :
2301+ if compute :
22702302 store .close ()
2303+ else :
2304+ store .sync ()
22712305
22722306 if not compute :
22732307 import dask
22742308
22752309 return dask .delayed (
2276- list (
2277- starmap (dask .delayed (_finalize_store ), zip (writes , stores , strict = True ))
2278- )
2310+ list (starmap (delayed_close_after_writes , zip (writes , stores , strict = True )))
22792311 )
22802312
22812313
2314+ def get_writable_zarr_store (
2315+ store : ZarrStoreLike | None = None ,
2316+ * ,
2317+ chunk_store : MutableMapping | str | os .PathLike | None = None ,
2318+ mode : ZarrWriteModes | None = None ,
2319+ synchronizer = None ,
2320+ group : str | None = None ,
2321+ consolidated : bool | None = None ,
2322+ append_dim : Hashable | None = None ,
2323+ region : Mapping [str , slice | Literal ["auto" ]] | Literal ["auto" ] | None = None ,
2324+ safe_chunks : bool = True ,
2325+ align_chunks : bool = False ,
2326+ storage_options : dict [str , str ] | None = None ,
2327+ zarr_version : int | None = None ,
2328+ zarr_format : int | None = None ,
2329+ write_empty_chunks : bool | None = None ,
2330+ ) -> backends .ZarrStore :
2331+ """Create a store for writing to Zarr."""
2332+ from xarray .backends .zarr import _choose_default_mode , _get_mappers
2333+
2334+ kwargs , mapper , chunk_mapper = _get_mappers (
2335+ storage_options = storage_options , store = store , chunk_store = chunk_store
2336+ )
2337+ mode = _choose_default_mode (mode = mode , append_dim = append_dim , region = region )
2338+
2339+ if mode == "r+" :
2340+ already_consolidated = consolidated
2341+ consolidate_on_close = False
2342+ else :
2343+ already_consolidated = False
2344+ consolidate_on_close = consolidated or consolidated is None
2345+
2346+ return backends .ZarrStore .open_group (
2347+ store = mapper ,
2348+ mode = mode ,
2349+ synchronizer = synchronizer ,
2350+ group = group ,
2351+ consolidated = already_consolidated ,
2352+ consolidate_on_close = consolidate_on_close ,
2353+ chunk_store = chunk_mapper ,
2354+ append_dim = append_dim ,
2355+ write_region = region ,
2356+ safe_chunks = safe_chunks ,
2357+ align_chunks = align_chunks ,
2358+ zarr_version = zarr_version ,
2359+ zarr_format = zarr_format ,
2360+ write_empty = write_empty_chunks ,
2361+ ** kwargs ,
2362+ )
2363+
2364+
22822365# compute=True returns ZarrStore
22832366@overload
22842367def to_zarr (
@@ -2353,7 +2436,6 @@ def to_zarr(
23532436
23542437 See `Dataset.to_zarr` for full API docs.
23552438 """
2356- from xarray .backends .zarr import _choose_default_mode , _get_mappers
23572439
23582440 # validate Dataset keys, DataArray names
23592441 _validate_dataset_names (dataset )
@@ -2368,53 +2450,39 @@ def to_zarr(
23682450 if encoding is None :
23692451 encoding = {}
23702452
2371- kwargs , mapper , chunk_mapper = _get_mappers (
2372- storage_options = storage_options , store = store , chunk_store = chunk_store
2373- )
2374- mode = _choose_default_mode (mode = mode , append_dim = append_dim , region = region )
2375-
2376- if mode == "r+" :
2377- already_consolidated = consolidated
2378- consolidate_on_close = False
2379- else :
2380- already_consolidated = False
2381- consolidate_on_close = consolidated or consolidated is None
2382-
2383- zstore = backends .ZarrStore .open_group (
2384- store = mapper ,
2453+ zstore = get_writable_zarr_store (
2454+ store ,
2455+ chunk_store = chunk_store ,
23852456 mode = mode ,
23862457 synchronizer = synchronizer ,
23872458 group = group ,
2388- consolidated = already_consolidated ,
2389- consolidate_on_close = consolidate_on_close ,
2390- chunk_store = chunk_mapper ,
2459+ consolidated = consolidated ,
23912460 append_dim = append_dim ,
2392- write_region = region ,
2461+ region = region ,
23932462 safe_chunks = safe_chunks ,
23942463 align_chunks = align_chunks ,
2464+ storage_options = storage_options ,
23952465 zarr_version = zarr_version ,
23962466 zarr_format = zarr_format ,
2397- write_empty = write_empty_chunks ,
2398- ** kwargs ,
2467+ write_empty_chunks = write_empty_chunks ,
23992468 )
24002469
2401- dataset = zstore ._validate_and_autodetect_region (
2402- dataset ,
2403- )
2470+ dataset = zstore ._validate_and_autodetect_region (dataset )
24042471 zstore ._validate_encoding (encoding )
24052472
24062473 writer = ArrayWriter ()
2407- # TODO: figure out how to properly handle unlimited_dims
2408- dump_to_store (dataset , zstore , writer , encoding = encoding )
2409- writes = writer .sync (
2410- compute = compute , chunkmanager_store_kwargs = chunkmanager_store_kwargs
2411- )
24122474
2413- if compute :
2414- _finalize_store (writes , zstore )
2415- else :
2416- import dask
2475+ # TODO: figure out how to properly handle unlimited_dims
2476+ try :
2477+ dump_to_store (dataset , zstore , writer , encoding = encoding )
2478+ writes = writer .sync (
2479+ compute = compute , chunkmanager_store_kwargs = chunkmanager_store_kwargs
2480+ )
2481+ finally :
2482+ if compute :
2483+ zstore .close ()
24172484
2418- return dask .delayed (_finalize_store )(writes , zstore )
2485+ if not compute :
2486+ return delayed_close_after_writes (writes , zstore )
24192487
24202488 return zstore
0 commit comments