From 86563b9bdbcfd40f0eaba7ce3684aef955fbc249 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 8 Aug 2025 15:00:06 -0700 Subject: [PATCH 01/37] Rewrite DataTree.to_netcdf and support netCDF4 in-memory This PR includes a handful of significant changes: 1. It refactors the internal structure of `DataTree.to_netcdf()` and `DataTree.to_zarr()` to use lower level interfaces, rather than calling `Dataset` methods. This allows for properly supporting `compute=False` (and likely various other improvements). 2. Reading and writing in-memory data with netCDF4-python is now supported, including DataTree. 3. The `engine` argument in `DataTree.to_netcdf()` is now set consistently with `Dataset.to_netcdf()`, preferring `netcdf4` to `h5netcdf`. 4. Calling `Dataset.to_netcdf()` without a target now always returns a `memoryview` object, *including* in the case where `engine='scipy'` is used (which currently returns `bytes`). This is a breaking change, rather than merely issuing a warning as is done in #10571. I believe it probably makes sense to do this as a breaking change because (1) it offers significant performance benefits, (2) the default behavior without specifying an engine will already change (because `netcdf4` is preferred to the `scipy` backend) and (3) restoring previous behavior is easy (by wrapping the memoryview with `bytes()`). 
mypy --- xarray/__init__.py | 2 + xarray/backends/api.py | 388 +++++++++++++-------- xarray/backends/common.py | 29 +- xarray/backends/file_manager.py | 99 +++++- xarray/backends/h5netcdf_.py | 23 +- xarray/backends/locks.py | 4 +- xarray/backends/netCDF4_.py | 96 ++++- xarray/backends/scipy_.py | 66 ++-- xarray/backends/zarr.py | 18 +- xarray/core/datatree.py | 6 +- xarray/core/datatree_io.py | 173 ++++++--- xarray/tests/__init__.py | 4 + xarray/tests/test_backends.py | 223 ++++++------ xarray/tests/test_backends_api.py | 4 +- xarray/tests/test_backends_datatree.py | 48 ++- xarray/tests/test_backends_file_manager.py | 18 +- 16 files changed, 796 insertions(+), 405 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 04fb5b03867..7901fffcbed 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -4,6 +4,7 @@ from xarray.backends.api import ( load_dataarray, load_dataset, + load_datatree, open_dataarray, open_dataset, open_datatree, @@ -96,6 +97,7 @@ "infer_freq", "load_dataarray", "load_dataset", + "load_datatree", "map_blocks", "map_over_datasets", "merge", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2a6476ea828..aaf6bae0d6a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -30,13 +30,14 @@ from xarray.backends import plugins from xarray.backends.common import ( AbstractDataStore, + AbstractWritableDataStore, ArrayWriter, BytesIOProxy, T_PathFileOrDataStore, _find_absolute_paths, _normalize_path, ) -from xarray.backends.locks import _get_scheduler +from xarray.backends.locks import get_dask_scheduler from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder from xarray.core import dtypes, indexing from xarray.core.coordinates import Coordinates @@ -132,30 +133,47 @@ def _get_default_engine_gz() -> Literal["scipy"]: return engine -def _get_default_engine_netcdf() -> Literal["netcdf4", "h5netcdf", "scipy"]: - candidates: list[tuple[str, str]] = [ - ("netcdf4", "netCDF4"), - ("h5netcdf", "h5netcdf"), 
- ("scipy", "scipy.io.netcdf"), - ] +def get_default_engine_netcdf( + format: T_NetcdfTypes | None, +) -> Literal["netcdf4", "h5netcdf", "scipy"]: + engines = { + "netcdf4": "netCDF4", + "scipy": "scipy.io.netcdf", + "h5netcdf": "h5netcdf", + } + + if format is None: + candidates = ["netcdf4", "h5netcdf", "scipy"] + elif format.upper().startswith("NETCDF3"): + candidates = ["netcdf4", "scipy"] + elif format.upper().startswith("NETCDF4"): + candidates = ["netcdf4", "h5netcdf"] + else: + raise AssertionError(f"unexpected {format=}") - for engine, module_name in candidates: + for engine in candidates: + module_name = engines[engine] if importlib.util.find_spec(module_name) is not None: return cast(Literal["netcdf4", "h5netcdf", "scipy"], engine) + format_str = f"with {format=}" if format is not None else "" raise ValueError( - "cannot read or write NetCDF files because none of " - "'netCDF4-python', 'h5netcdf', or 'scipy' are installed" + f"cannot read or write NetCDF files{format_str} because none of " + f"{set(candidates)} are installed" ) -def _get_default_engine(path: str, allow_remote: bool = False) -> T_NetcdfEngine: - if allow_remote and is_remote_uri(path): - return _get_default_engine_remote_uri() # type: ignore[return-value] - elif path.endswith(".gz"): - return _get_default_engine_gz() - else: - return _get_default_engine_netcdf() +def _get_default_engine( + path: str | None, + allow_remote: bool = False, + format: T_NetcdfTypes | None = None, +) -> T_NetcdfEngine: + if path is not None: + if allow_remote and is_remote_uri(path): + return _get_default_engine_remote_uri() # type: ignore[return-value] + if path.endswith(".gz"): + return _get_default_engine_gz() + return get_default_engine_netcdf(format) def _validate_dataset_names(dataset: Dataset) -> None: @@ -283,18 +301,24 @@ def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None: _protect_dataset_variables_inplace(node, cache) -def _finalize_store(write, store): +def 
_finalize_store(writes, store): """Finalize this store by explicitly syncing and closing""" - del write # ensure writing is done first + del writes # ensure writing is done first store.close() +def delayed_close_after_writes(writes, store): + import dask + + return dask.delayed(_finalize_store)(writes, store) + + def _multi_file_closer(closers): for closer in closers: closer() -def load_dataset(filename_or_obj, **kwargs) -> Dataset: +def load_dataset(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> Dataset: """Open, load into memory, and close a Dataset from a file or file-like object. @@ -320,7 +344,7 @@ def load_dataset(filename_or_obj, **kwargs) -> Dataset: return ds.load() -def load_dataarray(filename_or_obj, **kwargs): +def load_dataarray(filename_or_obj: T_PathFileOrDataStore, **kwargs): """Open, load into memory, and close a DataArray from a file or file-like object containing a single data variable. @@ -346,6 +370,32 @@ def load_dataarray(filename_or_obj, **kwargs): return da.load() +def load_datatree(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> DataTree: + """Open, load into memory, and close a DataTree from a file or file-like + object. + + This is a thin wrapper around :py:meth:`~xarray.open_datatree`. It differs + from `open_datatree` in that it loads the Dataset into memory, closes the + file, and returns the Dataset. In contrast, `open_datatree` keeps the file + handle open and lazy loads its contents. All parameters are passed directly + to `open_datatree`. See that documentation for further details. + + Returns + ------- + datatree : DataTree + The newly created DataTree. 
+ + See Also + -------- + open_datatree + """ + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") + + with open_datatree(filename_or_obj, **kwargs) as ds: + return ds.load() + + def _chunk_ds( backend_ds, filename_or_obj, @@ -512,14 +562,12 @@ def open_dataset( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -753,10 +801,9 @@ def open_dataarray( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, decode_timedelta: bool | CFTimedeltaCoder | None = None, use_cftime: bool | None = None, concat_characters: bool | None = None, @@ -981,14 +1028,12 @@ def open_datatree( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | 
CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -1221,14 +1266,12 @@ def open_groups( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -1460,10 +1503,9 @@ def open_groups( def open_mfdataset( - paths: str - | os.PathLike - | ReadBuffer - | NestedSequence[str | os.PathLike | ReadBuffer], + paths: ( + str | os.PathLike | ReadBuffer | NestedSequence[str | os.PathLike | ReadBuffer] + ), chunks: T_Chunks = None, concat_dim: ( str @@ -1477,10 +1519,9 @@ def open_mfdataset( compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, preprocess: Callable[[Dataset], Dataset] | None = None, engine: T_Engine = None, - data_vars: Literal["all", "minimal", "different"] - | None - | list[str] - | CombineKwargDefault = _DATA_VARS_DEFAULT, + data_vars: ( + Literal["all", "minimal", "different"] | None | list[str] | CombineKwargDefault + ) = _DATA_VARS_DEFAULT, coords=_COORDS_DEFAULT, combine: Literal["by_coords", "nested"] = "by_coords", parallel: bool = False, @@ -1769,6 +1810,38 @@ def open_mfdataset( } +def get_writable_netcdf_store( + target, + engine: T_NetcdfEngine, + *, + format: T_NetcdfTypes | None, + mode: NetcdfWriteModes, + autoclose: bool, + 
invalid_netcdf: bool, + auto_complex: bool | None, +) -> AbstractWritableDataStore: + try: + store_open = WRITEABLE_STORES[engine] + except KeyError as err: + raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err + + if format is not None: + format = format.upper() # type: ignore[assignment] + + kwargs = dict(autoclose=True) if autoclose else {} + if invalid_netcdf: + if engine == "h5netcdf": + kwargs["invalid_netcdf"] = invalid_netcdf + else: + raise ValueError( + f"unrecognized option 'invalid_netcdf' for engine {engine}" + ) + if auto_complex is not None: + kwargs["auto_complex"] = auto_complex + + return store_open(target, mode=mode, format=format, **kwargs) + + # multifile=True returns writer and datastore @overload def to_netcdf( @@ -1926,42 +1999,28 @@ def to_netcdf( if encoding is None: encoding = {} - if isinstance(path_or_file, str): + if isinstance(path_or_file, str) or path_or_file is None: if engine is None: - engine = _get_default_engine(path_or_file) + engine = _get_default_engine(path_or_file, format=format) path_or_file = _normalize_path(path_or_file) - else: - # writing to bytes/memoryview or a file-like object - if engine is None: - # TODO: only use 'scipy' if format is None or a netCDF3 format - engine = "scipy" - elif engine not in ("scipy", "h5netcdf"): - raise ValueError( - "invalid engine for creating bytes/memoryview or writing to a " - f"file-like object with to_netcdf: {engine!r}. Only " - "engine=None, engine='scipy' and engine='h5netcdf' is " - "supported." - ) - if not compute: - raise NotImplementedError( - "to_netcdf() with compute=False is not yet implemented when " - "returning bytes" - ) + # writing to a file-like object + elif engine is None: + # TODO: only use 'scipy' if format is None or a netCDF3 format + engine = "scipy" + elif engine not in ("scipy", "h5netcdf"): + raise ValueError( + "invalid engine for creating bytes/memoryview or writing to a " + f"file-like object with to_netcdf: {engine!r}. 
Only " + "engine=None, engine='scipy' and engine='h5netcdf' is " + "supported." + ) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) _validate_attrs(dataset, engine, invalid_netcdf) - try: - store_open = WRITEABLE_STORES[engine] - except KeyError as err: - raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err - - if format is not None: - format = format.upper() # type: ignore[assignment] - # handle scheduler specific logic - scheduler = _get_scheduler() + scheduler = get_dask_scheduler() have_chunks = any(v.chunks is not None for v in dataset.variables.values()) autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] @@ -1972,30 +2031,26 @@ def to_netcdf( ) if path_or_file is None: + if not compute: + raise NotImplementedError( + "to_netcdf() with compute=False is not yet implemented when " + "returning a memoryview" + ) target = BytesIOProxy() else: target = path_or_file # type: ignore[assignment] - kwargs = dict(autoclose=True) if autoclose else {} - if invalid_netcdf: - if engine == "h5netcdf": - kwargs["invalid_netcdf"] = invalid_netcdf - else: - raise ValueError( - f"unrecognized option 'invalid_netcdf' for engine {engine}" - ) - if auto_complex is not None: - kwargs["auto_complex"] = auto_complex - - store = store_open(target, mode, format, group, **kwargs) - - if unlimited_dims is None: - unlimited_dims = dataset.encoding.get("unlimited_dims", None) - if unlimited_dims is not None: - if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): - unlimited_dims = [unlimited_dims] - else: - unlimited_dims = list(unlimited_dims) + store = get_writable_netcdf_store( + target, + engine, + mode=mode, + format=format, + autoclose=autoclose, + invalid_netcdf=invalid_netcdf, + auto_complex=auto_complex, + ) + if group is not None: + store = store.get_child_store(group) writer = ArrayWriter() @@ -2021,12 +2076,10 @@ def to_netcdf( if path_or_file is None: assert 
isinstance(target, BytesIOProxy) # created in this function - return target.getvalue_or_getbuffer() + return target.getbuffer() if not compute: - import dask - - return dask.delayed(_finalize_store)(writes, store) + return delayed_close_after_writes(writes, store) return None @@ -2041,6 +2094,15 @@ def dump_to_store( if encoding is None: encoding = {} + if unlimited_dims is None: + unlimited_dims = dataset.encoding.get("unlimited_dims", None) + + if unlimited_dims is not None: + if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): + unlimited_dims = [unlimited_dims] + else: + unlimited_dims = list(unlimited_dims) + variables, attrs = conventions.encode_dataset_coordinates(dataset) check_encoding = set() @@ -2190,12 +2252,60 @@ def save_mfdataset( import dask return dask.delayed( - list( - starmap(dask.delayed(_finalize_store), zip(writes, stores, strict=True)) - ) + list(starmap(delayed_close_after_writes, zip(writes, stores, strict=True))) ) +def get_writable_zarr_store( + store: ZarrStoreLike | None = None, + *, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + align_chunks: bool = False, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + zarr_format: int | None = None, + write_empty_chunks: bool | None = None, +) -> backends.ZarrStore: + from xarray.backends.zarr import _choose_default_mode, _get_mappers + + kwargs, mapper, chunk_mapper = _get_mappers( + storage_options=storage_options, store=store, chunk_store=chunk_store + ) + mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region) + + if mode == "r+": + already_consolidated = consolidated + consolidate_on_close = False + else: + 
already_consolidated = False + consolidate_on_close = consolidated or consolidated is None + + return backends.ZarrStore.open_group( + store=mapper, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=already_consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_mapper, + append_dim=append_dim, + write_region=region, + safe_chunks=safe_chunks, + align_chunks=align_chunks, + zarr_version=zarr_version, + zarr_format=zarr_format, + write_empty=write_empty_chunks, + **kwargs, + ) + + # compute=True returns ZarrStore @overload def to_zarr( @@ -2270,8 +2380,6 @@ def to_zarr( See `Dataset.to_zarr` for full API docs. """ - from xarray.backends.zarr import _choose_default_mode, _get_mappers - # validate Dataset keys, DataArray names _validate_dataset_names(dataset) @@ -2285,53 +2393,39 @@ def to_zarr( if encoding is None: encoding = {} - kwargs, mapper, chunk_mapper = _get_mappers( - storage_options=storage_options, store=store, chunk_store=chunk_store - ) - mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region) - - if mode == "r+": - already_consolidated = consolidated - consolidate_on_close = False - else: - already_consolidated = False - consolidate_on_close = consolidated or consolidated is None - - zstore = backends.ZarrStore.open_group( - store=mapper, + zstore = get_writable_zarr_store( + store, + chunk_store=chunk_store, mode=mode, synchronizer=synchronizer, group=group, - consolidated=already_consolidated, - consolidate_on_close=consolidate_on_close, - chunk_store=chunk_mapper, + consolidated=consolidated, append_dim=append_dim, - write_region=region, + region=region, safe_chunks=safe_chunks, align_chunks=align_chunks, + storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, - write_empty=write_empty_chunks, - **kwargs, + write_empty_chunks=write_empty_chunks, ) - dataset = zstore._validate_and_autodetect_region( - dataset, - ) + dataset = 
zstore._validate_and_autodetect_region(dataset) zstore._validate_encoding(encoding) writer = ArrayWriter() - # TODO: figure out how to properly handle unlimited_dims - dump_to_store(dataset, zstore, writer, encoding=encoding) - writes = writer.sync( - compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs - ) - if compute: - _finalize_store(writes, zstore) - else: - import dask + # TODO: figure out how to properly handle unlimited_dims + try: + dump_to_store(dataset, zstore, writer, encoding=encoding) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) + finally: + if compute: + zstore.close() - return dask.delayed(_finalize_store)(writes, zstore) + if not compute: + return delayed_close_after_writes(writes, zstore) return zstore diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 542ca4c897b..a0acafb7622 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -11,7 +11,7 @@ TYPE_CHECKING, Any, ClassVar, - Generic, + Self, TypeVar, Union, overload, @@ -197,22 +197,19 @@ def _normalize_path_list( return _normalize_path_list(paths) -BytesOrMemory = TypeVar("BytesOrMemory", bytes, memoryview) - - @dataclass -class BytesIOProxy(Generic[BytesOrMemory]): - """Proxy object for a write that returns either bytes or a memoryview.""" +class BytesIOProxy: + """Proxy object for a write that returns a memoryview.""" - # TODO: remove this in favor of BytesIO when Dataset.to_netcdf() stops - # returning bytes from the scipy engine - getvalue: Callable[[], BytesOrMemory] | None = None + getter: Callable[[], memoryview] | None = None - def getvalue_or_getbuffer(self) -> BytesOrMemory: - """Get the value of this write as bytes or memory.""" - if self.getvalue is None: - raise ValueError("must set getvalue before fetching value") - return self.getvalue() + # TODO: rename this to getbfuffer() when Dataset.to_netcdf() stops returning + # bytes from the scipy engine + def getbuffer(self) 
-> memoryview: + """Get the value of this write a memoryview.""" + if self.getter is None: + raise ValueError("must set getter before fetching value") + return self.getter() def _open_remote_file(file, mode, storage_options=None): @@ -305,6 +302,10 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None): class AbstractDataStore: __slots__ = () + def get_child_store(self, group: str) -> Self: # pragma: no cover + """Get a store corresponding to the indicated child group.""" + raise NotImplementedError() + def get_dimensions(self): # pragma: no cover raise NotImplementedError() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 2a6f3691faf..bc8c199a913 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -6,7 +6,7 @@ import threading import uuid import warnings -from collections.abc import Hashable +from collections.abc import Hashable, Mapping from typing import Any from xarray.backends.locks import acquire @@ -82,9 +82,9 @@ class CachingFileManager(FileManager): def __init__( self, opener, - *args, - mode=_DEFAULT_MODE, - kwargs=None, + *args: Any, + mode: Any = _DEFAULT_MODE, + kwargs: Mapping[str, Any] | None = None, lock=None, cache=None, manager_id: Hashable | None = None, @@ -290,13 +290,6 @@ def __repr__(self) -> str: ) -@atexit.register -def _remove_del_method(): - # We don't need to close unclosed files at program exit, and may not be able - # to, because Python is cleaning up imports / globals. - del CachingFileManager.__del__ - - class _RefCounter: """Class for keeping track of reference counts.""" @@ -336,6 +329,90 @@ def __hash__(self): return self.hashvalue +class PickleableFileManager(FileManager): + """File manager that supports pickling by reopening a file object. 
+ + Use PickleableFileManager for wrapping file-like objects that do not natively + support pickling (e.g., netCDF4.Dataset and h5netcdf.File) in cases where a + global cache is not desirable (e.g., for netCDF files opened from bytes in + memory, or from existing file objects). + """ + + def __init__( + self, + opener, + *args, + mode=_DEFAULT_MODE, + kwargs=None, + ): + kwargs = {} if kwargs is None else dict(kwargs) + self._opener = opener + self._args = args + self._mode = "a" if mode == "w" else mode + self._kwargs = kwargs + + # Note: No need for locking with PickleableFileManager, because all + # opening of files happens in the constructor. + if mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs["mode"] = mode + self._file = opener(*args, **kwargs) + self._closed = False + + def acquire(self, needs_lock=True): + return self._file + + @contextlib.contextmanager + def acquire_context(self, needs_lock=True): + yield self._file + + def close(self, needs_lock=True): + if not self._closed: + self._file.close() + self._closed = True + + def __del__(self) -> None: + # If opener() raised an error in the constructor, _closed may not be set + if not getattr(self, "_closed", True): + self.close() + + if OPTIONS["warn_for_unclosed_files"]: + warnings.warn( + f"deallocating {self}, but file is not already closed. 
" + "This may indicate a bug.", + RuntimeWarning, + stacklevel=2, + ) + + def __getstate__(self): + # file is intentionally omitted: we want to open it again + return (self._opener, self._args, self._mode, self._kwargs) + + def __setstate__(self, state) -> None: + opener, args, mode, kwargs = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs) # type: ignore[misc] + + def __repr__(self) -> str: + args_string = ", ".join(map(repr, self._args)) + if self._mode is not _DEFAULT_MODE: + args_string += f", mode={self._mode!r}" + if "memory" in self._kwargs: + kwargs = self._kwargs | {"memory": utils.ReprObject("...")} + else: + kwargs = self._kwargs + return ( + f"{type(self).__name__}({self._opener!r}, {args_string}, kwargs={kwargs})" + ) + + +@atexit.register +def _remove_del_methods(): + # We don't need to close unclosed files at program exit, and may not be able + # to, because Python is cleaning up imports / globals. + del CachingFileManager.__del__ + del PickleableFileManager.__del__ + + class DummyFileManager(FileManager): """FileManager that simply wraps an open file in the FileManager interface.""" diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 24a3324bf62..ae4351c989a 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -4,7 +4,7 @@ import io import os from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -23,6 +23,7 @@ CachingFileManager, DummyFileManager, FileManager, + PickleableFileManager, ) from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from xarray.backends.netCDF4_ import ( @@ -149,6 +150,17 @@ def __init__( self.lock = ensure_lock(lock) self.autoclose = autoclose + def get_child_store(self, group: str) -> Self: + if self._group is not None: + group = os.path.join(self._group, group) + return type(self)( + self._manager, + group=group, + mode=self._mode, + 
lock=self.lock, + autoclose=self.autoclose, + ) + @classmethod def open( cls, @@ -176,7 +188,7 @@ def open( if isinstance(filename, BytesIOProxy): source = filename filename = io.BytesIO() - source.getvalue = filename.getbuffer + source.getter = filename.getbuffer if isinstance(filename, io.IOBase) and mode == "r": magic_number = read_magic_number_from_file(filename) @@ -204,11 +216,10 @@ def open( else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) - manager = ( - CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) - if isinstance(filename, str) - else h5netcdf.File(filename, mode=mode, **kwargs) + manager_cls = ( + CachingFileManager if isinstance(filename, str) else PickleableFileManager ) + manager = manager_cls(h5netcdf.File, filename, mode=mode, kwargs=kwargs) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index c6a06dd714e..82d3e0b7dae 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -134,7 +134,7 @@ def _get_lock_maker(scheduler=None): raise KeyError(scheduler) -def _get_scheduler(get=None, collection=None) -> str | None: +def get_dask_scheduler(get=None, collection=None) -> str | None: """Determine the dask scheduler that is being used. None is returned if no dask scheduler is active. @@ -184,7 +184,7 @@ def get_write_lock(key): ------- Lock object that can be used like a threading.Lock object. 
""" - scheduler = _get_scheduler() + scheduler = get_dask_scheduler() lock_maker = _get_lock_maker(scheduler) return lock_maker(key) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index ab1841461f4..c3b97d03070 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,7 +5,8 @@ import os from collections.abc import Iterable from contextlib import suppress -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -13,6 +14,7 @@ BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, + BytesIOProxy, T_PathFileOrDataStore, WritableCFDataStore, _normalize_path, @@ -20,7 +22,11 @@ find_root_and_group, robust_getitem, ) -from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.file_manager import ( + CachingFileManager, + DummyFileManager, + PickleableFileManager, +) from xarray.backends.locks import ( HDF5_LOCK, NETCDFC_LOCK, @@ -47,6 +53,7 @@ from xarray.core.variable import Variable if TYPE_CHECKING: + import netCDF4 from h5netcdf.core import EnumType as h5EnumType from netCDF4 import EnumType as ncEnumType @@ -357,6 +364,26 @@ def _build_and_get_enum( return datatype +@dataclass +class _Thunk: + value: Any + + def __call__(self): + return self.value + + +@dataclass +class _CloseWithCopy: + """Wrapper around netCDF4's esoteric interface for writing in-memory data.""" + + proxy: BytesIOProxy + nc4_dataset: netCDF4.Dataset + + def __call__(self): + value = self.nc4_dataset.close() + self.proxy.getter = _Thunk(value) + + class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. 
@@ -400,6 +427,17 @@ def __init__( self.lock = ensure_lock(lock) self.autoclose = autoclose + def get_child_store(self, group: str) -> Self: + if self._group is not None: + group = os.path.join(self._group, group) + return type(self)( + self._manager, + group=group, + mode=self._mode, + lock=self.lock, + autoclose=self.autoclose, + ) + @classmethod def open( cls, @@ -420,10 +458,11 @@ def open( if isinstance(filename, os.PathLike): filename = os.fspath(filename) - if not isinstance(filename, str): - raise ValueError( - "can only read bytes or file-like objects " - "with engine='scipy' or 'h5netcdf'" + if not isinstance(filename, str | bytes | memoryview | BytesIOProxy): + raise TypeError( + f"invalid filename for netCDF4 backend: {filename}" + # "can only read bytes or file-like objects " + # "with engine='scipy' or 'h5netcdf'" ) if format is None: @@ -431,16 +470,18 @@ def open( if lock is None: if mode == "r": - if is_remote_uri(filename): + if isinstance(filename, str) and is_remote_uri(filename): lock = NETCDFC_LOCK else: lock = NETCDF4_PYTHON_LOCK else: if format is None or format.startswith("NETCDF4"): - base_lock = NETCDF4_PYTHON_LOCK + lock = NETCDF4_PYTHON_LOCK else: - base_lock = NETCDFC_LOCK - lock = combine_locks([base_lock, get_write_lock(filename)]) + lock = NETCDFC_LOCK + + if isinstance(filename, str): + lock = combine_locks([lock, get_write_lock(filename)]) kwargs = dict( clobber=clobber, @@ -450,9 +491,31 @@ def open( ) if auto_complex is not None: kwargs["auto_complex"] = auto_complex - manager = CachingFileManager( - netCDF4.Dataset, filename, mode=mode, kwargs=kwargs - ) + + if isinstance(filename, BytesIOProxy): + assert mode == "w" + # Size hint used for creating netCDF3 files. Per the documentation + # for nc__create(), the special value NC_SIZEHINT_DEFAULT (which is + # the value 0), lets the netcdf library choose a suitable initial + # size. 
+ memory = 0 + kwargs["diskless"] = False + nc4_dataset = netCDF4.Dataset( + "", mode=mode, memory=memory, **kwargs + ) + close = _CloseWithCopy(filename, nc4_dataset) + manager = DummyFileManager(nc4_dataset, close=close) + + elif isinstance(filename, bytes | memoryview): + assert mode == "r" + kwargs["memory"] = filename + manager = PickleableFileManager( + netCDF4.Dataset, "", mode=mode, kwargs=kwargs + ) + else: + manager = CachingFileManager( + netCDF4.Dataset, filename, mode=mode, kwargs=kwargs + ) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): @@ -631,7 +694,12 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True - magic_number = try_read_magic_number_from_path(filename_or_obj) + + magic_number = ( + bytes(filename_or_obj[:8]) + if isinstance(filename_or_obj, bytes | memoryview) + else try_read_magic_number_from_path(filename_or_obj) + ) if magic_number is not None: # netcdf 3 or HDF5 return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index a93c6465d49..21b733cf618 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -30,12 +30,17 @@ Frozen, FrozenDict, close_on_error, - emit_user_level_warning, module_available, try_read_magic_number_from_file_or_path, ) from xarray.core.variable import Variable +try: + from scipy.io import netcdf_file as netcdf_file_base +except ImportError: + netcdf_file_base = object + + if TYPE_CHECKING: import scipy.io @@ -105,13 +110,35 @@ def __setitem__(self, key, value): raise -def _open_scipy_netcdf(filename, mode, mmap, version): +class flush_only_netcdf_file(netcdf_file_base): + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # 
https://github.com/scipy/scipy/issues/13905 + + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. + # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. + def close(self): + self.flush() + + def __del__(self): + # Remove the __del__ method. These files need to be closed explicitly by + # xarray. + pass + + +def _open_scipy_netcdf(filename, mode, mmap, version, flush_only=False): import scipy.io + netcdf_file = flush_only_netcdf_file if flush_only else scipy.io.netcdf_file + # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: - return scipy.io.netcdf_file( + return netcdf_file( gzip.open(filename), mode=mode, mmap=mmap, version=version ) except TypeError as e: @@ -125,7 +152,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise try: - return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) + return netcdf_file(filename, mode=mode, mmap=mmap, version=version) except TypeError as e: # netcdf3 message is obscure in this case errmsg = e.args[0] if "is not a valid NetCDF 3 file" in errmsg: @@ -169,20 +196,9 @@ def __init__( self.lock = ensure_lock(lock) if isinstance(filename_or_obj, BytesIOProxy): - emit_user_level_warning( - "return value of to_netcdf() without a target for " - "engine='scipy' is currently bytes, but will switch to " - "memoryview in a future version of Xarray. 
To silence this " - "warning, use the following pattern or switch to " - "to_netcdf(engine='h5netcdf'):\n" - " target = io.BytesIO()\n" - " dataset.to_netcdf(target)\n" - " result = target.getbuffer()", - FutureWarning, - ) source = filename_or_obj filename_or_obj = io.BytesIO() - source.getvalue = filename_or_obj.getvalue + source.getter = filename_or_obj.getbuffer if isinstance(filename_or_obj, str): # path manager = CachingFileManager( @@ -195,20 +211,16 @@ def __init__( elif hasattr(filename_or_obj, "seek"): # file object # Note: checking for .seek matches the check for file objects # in scipy.io.netcdf_file + flush_only = mode in "wa" scipy_dataset = _open_scipy_netcdf( - filename_or_obj, mode=mode, mmap=mmap, version=version + filename_or_obj, + mode=mode, + mmap=mmap, + version=version, + flush_only=flush_only, ) - # scipy.io.netcdf_file.close() incorrectly closes file objects that - # were passed in as constructor arguments: - # https://github.com/scipy/scipy/issues/13905 - # Instead of closing such files, only call flush(), which is - # equivalent as long as the netcdf_file object is not mmapped. - # This suffices to keep BytesIO objects open long enough to read - # their contents from to_netcdf(), but underlying files still get - # closed when the netcdf_file is garbage collected (via __del__), - # and will need to be fixed upstream in scipy. 
assert not scipy_dataset.use_mmap # no mmap for file objects - manager = DummyFileManager(scipy_dataset, close=scipy_dataset.flush) + manager = DummyFileManager(scipy_dataset) else: raise ValueError( f"cannot open {filename_or_obj=} with scipy.io.netcdf_file" diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1b62a87d10c..1f48254e861 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -5,7 +5,7 @@ import os import struct from collections.abc import Hashable, Iterable, Mapping -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Self, cast import numpy as np import pandas as pd @@ -735,6 +735,22 @@ def __init__( # on demand. self._members = self._fetch_members() + def get_child_store(self, group: str) -> Self: + zarr_group = self.zarr_group.require_group(group) + return type(self)( + zarr_group=zarr_group, + mode=self._mode, + consolidate_on_close=self._consolidate_on_close, + append_dim=self._append_dim, + write_region=self._write_region, + safe_chunks=self._safe_chunks, + write_empty=self._write_empty, + close_store_on_close=self._close_store_on_close, + use_zarr_fill_value_as_mask=self._use_zarr_fill_value_as_mask, + align_chunks=self._align_chunks, + cache_members=self._cache_members, + ) + @property def members(self) -> dict[str, ZarrArray | ZarrGroup]: """ diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bf82baccb31..332c2d16001 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -75,7 +75,9 @@ if TYPE_CHECKING: import numpy as np import pandas as pd + from dask.delayed import Delayed + from xarray.backends import ZarrStore from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.types import ( Dims, @@ -1788,7 +1790,7 @@ def to_zarr( write_inherited_coords: bool = False, compute: bool = True, **kwargs, - ): + ) -> ZarrStore | Delayed: """ Write datatree contents to a Zarr store. 
@@ -1831,7 +1833,7 @@ def to_zarr( """ from xarray.core.datatree_io import _datatree_to_zarr - _datatree_to_zarr( + return _datatree_to_zarr( self, store, mode=mode, diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index c586caaba89..522b791c183 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -1,10 +1,19 @@ from __future__ import annotations import io -from collections.abc import Mapping +from collections.abc import Hashable, Mapping, MutableMapping from os import PathLike from typing import TYPE_CHECKING, Any, Literal, get_args +from xarray.backends.api import ( + delayed_close_after_writes, + dump_to_store, + get_default_engine_netcdf, + get_writable_netcdf_store, + get_writable_zarr_store, +) +from xarray.backends.common import ArrayWriter, BytesIOProxy +from xarray.backends.locks import get_dask_scheduler from xarray.core.datatree import DataTree from xarray.core.types import NetcdfWriteModes, ZarrWriteModes @@ -12,6 +21,9 @@ T_DataTreeNetcdfTypes = Literal["NETCDF4"] if TYPE_CHECKING: + from dask.delayed import Delayed + + from xarray.backends import ZarrStore from xarray.core.types import ZarrStoreLike @@ -26,7 +38,8 @@ def _datatree_to_netcdf( group: str | None = None, write_inherited_coords: bool = False, compute: bool = True, - **kwargs, + invalid_netcdf: bool = False, + auto_complex: bool | None = None, ) -> None | memoryview: """Implementation of `DataTree.to_netcdf`.""" @@ -39,7 +52,7 @@ def _datatree_to_netcdf( ) if engine is None: - engine = "h5netcdf" + engine = get_default_engine_netcdf(format="NETCDF4") # type: ignore[assignment] if group is not None: raise NotImplementedError( @@ -61,36 +74,67 @@ def _datatree_to_netcdf( ) if filepath is None: - # No need to use BytesIOProxy here because the legacy scipy backend - # cannot write netCDF files with groups - target = io.BytesIO() + target = BytesIOProxy() else: target = filepath # type: ignore[assignment] if unlimited_dims is None: unlimited_dims = {} - 
for node in dt.subtree: - at_root = node is dt - ds = node.to_dataset(inherit=write_inherited_coords or at_root) - group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_netcdf( - target, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - unlimited_dims=unlimited_dims.get(node.path), - engine=engine, - format=format, - compute=compute, - **kwargs, - ) - mode = "a" + scheduler = get_dask_scheduler() + have_chunks = any( + v.chunks is not None for node in dt.subtree for v in node.variables.values() + ) + autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] + + root_store = get_writable_netcdf_store( + target, + engine, # type: ignore[arg-type] + mode=mode, + format=format, + autoclose=autoclose, + invalid_netcdf=invalid_netcdf, + auto_complex=auto_complex, + ) + if group is not None: + root_store = root_store.get_child(group) + + writer = ArrayWriter() + + try: + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + + for node in dt.subtree: + at_root = node is dt + dataset = node.to_dataset(inherit=write_inherited_coords or at_root) + node_store = ( + root_store if at_root else root_store.get_child_store(node.path) + ) + dump_to_store( + dataset, + node_store, + writer, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + ) + + if autoclose: + root_store.close() + + writes = writer.sync(compute=compute) + + finally: + if compute: + root_store.close() if filepath is None: - assert isinstance(target, io.BytesIO) + assert isinstance(target, BytesIOProxy) return target.getbuffer() + if not compute: + return delayed_close_after_writes(writes, root_store) + return None @@ -99,22 +143,31 @@ def _datatree_to_zarr( store: ZarrStoreLike, mode: ZarrWriteModes = "w-", encoding: Mapping[str, Any] | None = None, - consolidated: bool = True, + synchronizer=None, group: str | None = None, write_inherited_coords: bool = False, + *, + 
chunk_store: MutableMapping | str | PathLike | None = None, compute: bool = True, - **kwargs, -): + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + align_chunks: bool = False, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + zarr_format: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> ZarrStore | Delayed: """Implementation of `DataTree.to_zarr`.""" - from zarr import consolidate_metadata - if group is not None: raise NotImplementedError( "specifying a root group for the tree has not been implemented" ) - if "append_dim" in kwargs: + if append_dim is not None: raise NotImplementedError( "specifying ``append_dim`` with ``DataTree.to_zarr`` has not been implemented" ) @@ -130,21 +183,51 @@ def _datatree_to_zarr( f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" ) - for node in dt.subtree: - at_root = node is dt - ds = node.to_dataset(inherit=write_inherited_coords or at_root) - group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_zarr( - store, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - consolidated=False, - compute=compute, - **kwargs, + root_store = get_writable_zarr_store( + store, + chunk_store=chunk_store, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=consolidated, + append_dim=append_dim, + region=region, + safe_chunks=safe_chunks, + align_chunks=align_chunks, + storage_options=storage_options, + zarr_version=zarr_version, + zarr_format=zarr_format, + write_empty_chunks=write_empty_chunks, + ) + + writer = ArrayWriter() + + # TODO: figure out how to properly handle unlimited_dims + try: + for node in dt.subtree: + at_root = node is dt + dataset = node.to_dataset(inherit=write_inherited_coords or at_root) + node_store = 
( + root_store if at_root else root_store.get_child_store(node.path) + ) + + dataset = node_store._validate_and_autodetect_region(dataset) + node_store._validate_encoding(encoding) + + dump_to_store( + dataset, + node_store, + writer, + encoding=encoding.get(node.path), + ) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs ) - if "w" in mode: - mode = "a" + finally: + if compute: + root_store.close() + + if not compute: + return delayed_close_after_writes(writes, root_store) - if consolidated: - consolidate_metadata(store) + return root_store diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 787c01eaf62..4cdf89a76f9 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -173,6 +173,10 @@ def _importorskip( requires_scipy_or_netCDF4 = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) +has_h5netcdf_or_netCDF4 = has_h5netcdf or has_netCDF4 +requires_h5netcdf_or_netCDF4 = pytest.mark.skipif( + not has_h5netcdf_or_netCDF4, reason="requires h5netcdf or netCDF4" +) has_numbagg_or_bottleneck = has_numbagg or has_bottleneck requires_numbagg_or_bottleneck = pytest.mark.skipif( not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 2ff73203580..69bc67bb6c0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -82,6 +82,7 @@ requires_fsspec, requires_h5netcdf, requires_h5netcdf_1_4_0_or_above, + requires_h5netcdf_or_netCDF4, requires_h5netcdf_ros3, requires_iris, requires_netcdf, @@ -2296,6 +2297,79 @@ def test_deepcopy(self) -> None: assert_identical(expected, copied) +class InMemoryNetCDF: + engine: T_NetcdfEngine | None + + def test_roundtrip(self) -> None: + original = create_test_data() + result = original.to_netcdf(engine=self.engine) + roundtrip = load_dataset(result, engine=self.engine) + assert_identical(roundtrip, original) + + 
def test_roundtrip_via_memoryview(self) -> None: + original = create_test_data() + result = memoryview(original.to_netcdf(engine=self.engine)) + roundtrip = load_dataset(result, engine=self.engine) + assert_identical(roundtrip, original) + + def test_roundtrip_via_bytes(self) -> None: + original = create_test_data() + result = bytes(original.to_netcdf(engine=self.engine)) + roundtrip = load_dataset(result, engine=self.engine) + assert_identical(roundtrip, original) + + def test_pickle_open_dataset_from_bytes(self) -> None: + original = Dataset({"foo": ("x", [1, 2, 3])}) + netcdf_bytes = bytes(original.to_netcdf(engine=self.engine)) + with open_dataset(netcdf_bytes, engine=self.engine) as roundtrip: + unpickled = pickle.loads(pickle.dumps(roundtrip)) + assert_identical(unpickled, original) + unpickled.close() + + +class InMemoryNetCDFWithGroups(InMemoryNetCDF): + def test_roundtrip_group_via_memoryview(self) -> None: + original = create_test_data() + netcdf_bytes = original.to_netcdf(group="sub", engine=self.engine) + roundtrip = load_dataset(netcdf_bytes, group="sub", engine=self.engine) + assert_identical(roundtrip, original) + + +class FileObjectNetCDF: + engine: T_NetcdfEngine + + def test_open_twice(self) -> None: + expected = create_test_data() + expected.attrs["foo"] = "bar" + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine=self.engine) + with open(tmp_file, "rb") as f: + with open_dataset(f, engine=self.engine): + with open_dataset(f, engine=self.engine): + pass + + def test_file_remains_open(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}) + f = BytesIO() + data.to_netcdf(f, engine=self.engine) + assert not f.closed + restored = open_dataset(f, engine=self.engine) + assert not f.closed + assert_identical(restored, data) + restored.close() + assert not f.closed + + +@requires_h5netcdf_or_netCDF4 +class TestGenericNetCDF4InMemory(InMemoryNetCDFWithGroups): + engine = None + + +@requires_netCDF4 +class 
TestNetCDF4InMemory(InMemoryNetCDFWithGroups): + engine: T_NetcdfEngine = "netcdf4" + + @requires_netCDF4 @requires_dask @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") @@ -4048,7 +4122,7 @@ def test_zarr_version_deprecated() -> None: @requires_scipy -class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): +class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only, InMemoryNetCDF): engine: T_NetcdfEngine = "scipy" @contextlib.contextmanager @@ -4056,37 +4130,21 @@ def create_store(self): fobj = BytesIO() yield backends.ScipyDataStore(fobj, "w") - def test_to_netcdf_explicit_engine(self) -> None: - with pytest.warns( - FutureWarning, - match=re.escape("return value of to_netcdf() without a target"), - ): - Dataset({"foo": 42}).to_netcdf(engine="scipy") - - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - with pytest.warns( - FutureWarning, - match=re.escape("return value of to_netcdf() without a target"), - ): - netcdf_bytes = original.to_netcdf(engine="scipy") - roundtrip = open_dataset(netcdf_bytes, engine="scipy") - assert_identical(roundtrip, original) - - def test_bytes_pickle(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - with pytest.warns( - FutureWarning, - match=re.escape("return value of to_netcdf() without a target"), - ): - fobj = data.to_netcdf() - with self.open(fobj) as ds: - unpickled = pickle.loads(pickle.dumps(ds)) - assert_identical(unpickled, data) + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + saved = self.save(data, path=None, **save_kwargs) + with self.open(saved, **open_kwargs) as ds: + yield ds @requires_scipy -class TestScipyFileObject(CFEncodedBase, NetCDF3Only): +class TestScipyFileObject(CFEncodedBase, NetCDF3Only, FileObjectNetCDF): # TODO: Consider consolidating some of these cases (e.g., # 
test_file_remains_open) with TestH5NetCDFFileObject engine: T_NetcdfEngine = "scipy" @@ -4111,19 +4169,15 @@ def roundtrip( with self.open(f, **open_kwargs) as ds: yield ds + @pytest.mark.xfail(reason="not working yet") + def test_open_twice(self): + super().test_open_twice() + @pytest.mark.xfail( reason="scipy.io.netcdf_file closes files upon garbage collection" ) def test_file_remains_open(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - f = BytesIO() - data.to_netcdf(f, engine="scipy") - assert not f.closed - restored = open_dataset(f, engine="scipy") - assert not f.closed - assert_identical(restored, data) - restored.close() - assert not f.closed + super().test_file_remains_open() @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle(self) -> None: @@ -4241,8 +4295,6 @@ def test_engine(self) -> None: data = create_test_data() with pytest.raises(ValueError, match=r"unrecognized engine"): data.to_netcdf("foo.nc", engine="foobar") # type: ignore[call-overload] - with pytest.raises(ValueError, match=r"invalid engine"): - data.to_netcdf(engine="netcdf4") with create_tmp_file() as tmp_file: data.to_netcdf(tmp_file) @@ -4299,32 +4351,6 @@ def test_encoding_unlimited_dims(self) -> None: assert actual.encoding["unlimited_dims"] == set("y") assert_equal(ds, actual) - @requires_scipy - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - with pytest.warns( - FutureWarning, - match=re.escape("return value of to_netcdf() without a target"), - ): - netcdf_bytes = original.to_netcdf() - roundtrip = open_dataset(netcdf_bytes) - assert_identical(roundtrip, original) - - @pytest.mark.xfail( - reason="scipy.io.netcdf_file closes files upon garbage collection" - ) - @requires_scipy - def test_roundtrip_via_file_object(self) -> None: - original = create_test_data() - f = BytesIO() - original.to_netcdf(f) - assert not f.closed - restored = open_dataset(f) - assert not f.closed - assert_identical(restored, original) - 
restored.close() - assert not f.closed - @requires_h5netcdf @requires_netCDF4 @@ -4600,7 +4626,7 @@ def test_deepcopy(self) -> None: @requires_h5netcdf -class TestH5NetCDFFileObject(TestH5NetCDFData): +class TestH5NetCDFFileObject(TestH5NetCDFData, FileObjectNetCDF): engine: T_NetcdfEngine = "h5netcdf" def test_open_badbytes(self) -> None: @@ -4609,8 +4635,10 @@ def test_open_badbytes(self) -> None: ): with open_dataset(b"garbage"): pass - with pytest.raises(ValueError, match=r"can only read bytes"): - with open_dataset(b"garbage", engine="netcdf4"): + with pytest.raises( + ValueError, match=r"not the signature of a valid netCDF4 file" + ): + with open_dataset(b"garbage", engine="h5netcdf"): pass with pytest.raises( ValueError, match=r"not the signature of a valid netCDF4 file" @@ -4618,16 +4646,6 @@ def test_open_badbytes(self) -> None: with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): pass - def test_open_twice(self) -> None: - expected = create_test_data() - expected.attrs["foo"] = "bar" - with create_tmp_file() as tmp_file: - expected.to_netcdf(tmp_file, engine="h5netcdf") - with open(tmp_file, "rb") as f: - with open_dataset(f, engine="h5netcdf"): - with open_dataset(f, engine="h5netcdf"): - pass - @requires_scipy def test_open_fileobj(self) -> None: # open in-memory datasets instead of local file paths @@ -4661,31 +4679,10 @@ def test_open_fileobj(self) -> None: with open_dataset(f): # ensure file gets closed pass - def test_file_remains_open(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - f = BytesIO() - data.to_netcdf(f, engine="h5netcdf") - assert not f.closed - restored = open_dataset(f, engine="h5netcdf") - assert not f.closed - assert_identical(restored, data) - restored.close() - assert not f.closed - @requires_h5netcdf -class TestH5NetCDFInMemoryData: - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(engine="h5netcdf") - roundtrip = open_dataset(netcdf_bytes, 
engine="h5netcdf") - assert_identical(roundtrip, original) - - def test_roundtrip_group_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(group="sub", engine="h5netcdf") - roundtrip = open_dataset(netcdf_bytes, group="sub", engine="h5netcdf") - assert_identical(roundtrip, original) +class TestH5NetCDFInMemoryData(InMemoryNetCDFWithGroups): + engine: T_NetcdfEngine = "h5netcdf" @requires_h5netcdf @@ -6052,17 +6049,6 @@ def test_open_dataarray_options(self) -> None: with open_dataarray(tmp, drop_variables=["y"]) as loaded: assert_identical(expected, loaded) - @requires_scipy - def test_dataarray_to_netcdf_return_bytes(self) -> None: - # regression test for GH1410 - data = xr.DataArray([1, 2, 3]) - with pytest.warns( - FutureWarning, - match=re.escape("return value of to_netcdf() without a target"), - ): - output = data.to_netcdf(engine="scipy") - assert isinstance(output, bytes) - def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: original_da = DataArray(np.arange(12).reshape((3, 4))) @@ -6579,6 +6565,9 @@ def test_netcdf4_entrypoint(tmp_path: Path) -> None: assert entrypoint.guess_can_open("something-local.cdf") assert not entrypoint.guess_can_open("not-found-and-no-extension") + contents = ds.to_netcdf(engine="netcdf4") + _check_guess_can_open_and_open(entrypoint, contents, engine="netcdf4", expected=ds) + path = tmp_path / "baz" with open(path, "wb") as f: f.write(b"not-a-netcdf-file") @@ -6597,10 +6586,7 @@ def test_scipy_entrypoint(tmp_path: Path) -> None: with open(path, "rb") as f: _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) - with pytest.warns( - FutureWarning, match=re.escape("return value of to_netcdf() without a target") - ): - contents = ds.to_netcdf(engine="scipy") + contents = ds.to_netcdf(engine="scipy") _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) _check_guess_can_open_and_open( entrypoint, BytesIO(contents), engine="scipy", 
expected=ds @@ -6632,6 +6618,9 @@ def test_h5netcdf_entrypoint(tmp_path: Path) -> None: with open(path, "rb") as f: _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + contents = ds.to_netcdf(engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, contents, engine="h5netcdf", expected=ds) + assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc4") assert entrypoint.guess_can_open("something-local.cdf") diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index ed487b07450..8eb7cbc8803 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -7,7 +7,7 @@ import pytest import xarray as xr -from xarray.backends.api import _get_default_engine, _get_default_engine_netcdf +from xarray.backends.api import _get_default_engine, get_default_engine_netcdf from xarray.tests import ( assert_identical, assert_no_warnings, @@ -39,7 +39,7 @@ def test_default_engine_h5netcdf(monkeypatch): monkeypatch.delitem(sys.modules, "scipy", raising=False) monkeypatch.setattr(sys, "meta_path", []) - assert _get_default_engine_netcdf() == "h5netcdf" + assert get_default_engine_netcdf(format=None) == "h5netcdf" def test_custom_engine() -> None: diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index ec57993c4b2..4a65b50d448 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -10,8 +10,7 @@ import pytest import xarray as xr -from xarray.backends.api import open_datatree, open_groups -from xarray.core.datatree import DataTree +from xarray import DataTree, load_datatree, open_datatree, open_groups from xarray.testing import assert_equal, assert_identical from xarray.tests import ( has_zarr_v3, @@ -19,6 +18,7 @@ parametrize_zarr_format, requires_dask, requires_h5netcdf, + requires_h5netcdf_or_netCDF4, requires_netCDF4, requires_pydap, requires_zarr, @@ -265,6 
+265,21 @@ def test_write_subgroup(self, tmpdir): assert_equal(original_dt, roundtrip_dt) assert_identical(expected_dt, roundtrip_dt) + def test_roundtrip_via_memoryview_engine_specified(self, simple_datatree): + original_dt = simple_datatree + roundtrip_dt = load_datatree( + original_dt.to_netcdf(engine=self.engine), engine=self.engine + ) + assert_equal(original_dt, roundtrip_dt) + + +@requires_h5netcdf_or_netCDF4 +class TestGenericNetCDFIO: + def test_roundtrip_via_memoryview(self, simple_datatree): + original_dt = simple_datatree + roundtrip_dt = load_datatree(original_dt.to_netcdf()) + assert_equal(original_dt, roundtrip_dt) + @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): @@ -539,16 +554,6 @@ def test_phony_dims_warning(self, tmpdir) -> None: "phony_dim_3": 25, } - def test_roundtrip_via_bytes(self, simple_datatree): - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf()) - assert_equal(original_dt, roundtrip_dt) - - def test_roundtrip_via_bytes_engine_specified(self, simple_datatree): - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf(engine=self.engine)) - assert_equal(original_dt, roundtrip_dt) - def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree): original_dt = simple_datatree filepath = tmpdir + "/test.nc" @@ -575,6 +580,9 @@ def test_to_zarr(self, tmpdir, simple_datatree, zarr_format): with open_datatree(filepath, engine="zarr") as roundtrip_dt: assert_equal(original_dt, roundtrip_dt) + @pytest.mark.filterwarnings( + "ignore:Numcodecs codecs are not in the Zarr version 3 specification" + ) def test_zarr_encoding(self, tmpdir, simple_datatree, zarr_format): filepath = str(tmpdir / "test.zarr") original_dt = simple_datatree @@ -601,11 +609,10 @@ def test_zarr_encoding(self, tmpdir, simple_datatree, zarr_format): enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item] with pytest.raises(ValueError, match="unexpected encoding group.*"): - 
original_dt.to_zarr( - filepath, encoding=enc, engine="zarr", zarr_format=zarr_format - ) + original_dt.to_zarr(filepath, encoding=enc, zarr_format=zarr_format) @pytest.mark.xfail(reason="upstream zarr read-only changes have broken this test") + @pytest.mark.filterwarnings("ignore:Duplicate name") def test_to_zarr_zip_store(self, tmpdir, simple_datatree, zarr_format): from zarr.storage import ZipStore @@ -653,7 +660,9 @@ def test_to_zarr_compute_false( storepath = tmp_path / "test.zarr" original_dt = simple_datatree.chunk() - original_dt.to_zarr(str(storepath), compute=False, zarr_format=zarr_format) + result = original_dt.to_zarr( + str(storepath), compute=False, zarr_format=zarr_format + ) def assert_expected_zarr_files_exist( arr_dir: Path, @@ -724,6 +733,13 @@ def assert_expected_zarr_files_exist( zarr_format=zarr_format, ) + in_progress_dt = load_datatree(str(storepath), engine="zarr") + assert not in_progress_dt.equals(original_dt) + + result.compute() # type: ignore[union-attr] + written_dt = load_datatree(str(storepath), engine="zarr") + assert_identical(written_dt, original_dt) + def test_to_zarr_inherited_coords(self, tmpdir, zarr_format): original_dt = DataTree.from_dict( { diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index ab1ac4a06d9..97b65f876b9 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -7,7 +7,7 @@ import pytest -from xarray.backends.file_manager import CachingFileManager +from xarray.backends.file_manager import CachingFileManager, PickleableFileManager from xarray.backends.lru_cache import LRUCache from xarray.core.options import set_options from xarray.tests import assert_no_warnings @@ -262,3 +262,19 @@ class AcquisitionError(Exception): assert file_cache # file *was* already open manager.close() + + +def test_pickleable_file_manager_write_pickle(tmpdir) -> None: + path = str(tmpdir.join("testing.txt")) + manager = 
PickleableFileManager(open, path, mode="w") + f = manager.acquire() + f.write("foo") + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write("bar") + manager2.close() + manager.close() + + with open(path) as f: + assert f.read() == "foobar" From cce54776bcd50cd324e1cf1a75c719fb3474af67 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 11 Aug 2025 15:29:49 -0700 Subject: [PATCH 02/37] Refactor to_netcdf() and to_zarr() internals --- doc/whats-new.rst | 3 + xarray/backends/api.py | 221 ++++++++++++++++--------- xarray/backends/common.py | 9 + xarray/backends/h5netcdf_.py | 13 +- xarray/backends/locks.py | 4 +- xarray/backends/netCDF4_.py | 13 +- xarray/backends/zarr.py | 24 ++- xarray/core/datatree.py | 63 ++++++- xarray/core/datatree_io.py | 169 ++++++++++++++----- xarray/tests/test_backends_datatree.py | 33 +++- 10 files changed, 414 insertions(+), 138 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 78ef2875b31..438c4665bc5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,6 +15,9 @@ New Features - :py:meth:`DataTree.to_netcdf` can now write to a file-like object, or return bytes if called without a filepath. (:issue:`10570`) By `Matthew Willson `_. +- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and + :py:meth:`DataTree.to_zarr`. + By `Stephan Hoyer `_. 
Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2a6476ea828..6ad67518639 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -30,13 +30,14 @@ from xarray.backends import plugins from xarray.backends.common import ( AbstractDataStore, + AbstractWritableDataStore, ArrayWriter, BytesIOProxy, T_PathFileOrDataStore, _find_absolute_paths, _normalize_path, ) -from xarray.backends.locks import _get_scheduler +from xarray.backends.locks import get_dask_scheduler from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder from xarray.core import dtypes, indexing from xarray.core.coordinates import Coordinates @@ -283,12 +284,18 @@ def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None: _protect_dataset_variables_inplace(node, cache) -def _finalize_store(write, store): +def _finalize_store(writes, store): """Finalize this store by explicitly syncing and closing""" - del write # ensure writing is done first + del writes # ensure writing is done first store.close() +def delayed_close_after_writes(writes, store): + import dask + + return dask.delayed(_finalize_store)(writes, store) + + def _multi_file_closer(closers): for closer in closers: closer() @@ -1769,6 +1776,39 @@ def open_mfdataset( } +def get_writable_netcdf_store( + target, + engine: T_NetcdfEngine, + *, + format: T_NetcdfTypes | None, + mode: NetcdfWriteModes, + autoclose: bool, + invalid_netcdf: bool, + auto_complex: bool | None, +) -> AbstractWritableDataStore: + """Create a store for writing to a netCDF file.""" + try: + store_open = WRITEABLE_STORES[engine] + except KeyError as err: + raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err + + if format is not None: + format = format.upper() # type: ignore[assignment] + + kwargs = dict(autoclose=True) if autoclose else {} + if invalid_netcdf: + if engine == "h5netcdf": + kwargs["invalid_netcdf"] = invalid_netcdf + else: + raise ValueError( + 
f"unrecognized option 'invalid_netcdf' for engine {engine}" + ) + if auto_complex is not None: + kwargs["auto_complex"] = auto_complex + + return store_open(target, mode=mode, format=format, **kwargs) + + # multifile=True returns writer and datastore @overload def to_netcdf( @@ -1952,16 +1992,8 @@ def to_netcdf( _validate_dataset_names(dataset) _validate_attrs(dataset, engine, invalid_netcdf) - try: - store_open = WRITEABLE_STORES[engine] - except KeyError as err: - raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err - - if format is not None: - format = format.upper() # type: ignore[assignment] - # handle scheduler specific logic - scheduler = _get_scheduler() + scheduler = get_dask_scheduler() have_chunks = any(v.chunks is not None for v in dataset.variables.values()) autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] @@ -1976,26 +2008,17 @@ def to_netcdf( else: target = path_or_file # type: ignore[assignment] - kwargs = dict(autoclose=True) if autoclose else {} - if invalid_netcdf: - if engine == "h5netcdf": - kwargs["invalid_netcdf"] = invalid_netcdf - else: - raise ValueError( - f"unrecognized option 'invalid_netcdf' for engine {engine}" - ) - if auto_complex is not None: - kwargs["auto_complex"] = auto_complex - - store = store_open(target, mode, format, group, **kwargs) - - if unlimited_dims is None: - unlimited_dims = dataset.encoding.get("unlimited_dims", None) - if unlimited_dims is not None: - if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): - unlimited_dims = [unlimited_dims] - else: - unlimited_dims = list(unlimited_dims) + store = get_writable_netcdf_store( + target, + engine, + mode=mode, + format=format, + autoclose=autoclose, + invalid_netcdf=invalid_netcdf, + auto_complex=auto_complex, + ) + if group is not None: + store = store.get_child_store(group) writer = ArrayWriter() @@ -2016,17 +2039,18 @@ def to_netcdf( writes = writer.sync(compute=compute) finally: - if not 
multifile and compute: # type: ignore[redundant-expr] - store.close() + if not multifile: + if compute: + store.close() + else: + store.sync() if path_or_file is None: assert isinstance(target, BytesIOProxy) # created in this function return target.getvalue_or_getbuffer() if not compute: - import dask - - return dask.delayed(_finalize_store)(writes, store) + return delayed_close_after_writes(writes, store) return None @@ -2041,6 +2065,15 @@ def dump_to_store( if encoding is None: encoding = {} + if unlimited_dims is None: + unlimited_dims = dataset.encoding.get("unlimited_dims", None) + + if unlimited_dims is not None: + if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): + unlimited_dims = [unlimited_dims] + else: + unlimited_dims = list(unlimited_dims) + variables, attrs = conventions.encode_dataset_coordinates(dataset) check_encoding = set() @@ -2182,20 +2215,71 @@ def save_mfdataset( try: writes = [w.sync(compute=compute) for w in writers] finally: - if compute: - for store in stores: + for store in stores: + if compute: store.close() + else: + store.sync() if not compute: import dask return dask.delayed( - list( - starmap(dask.delayed(_finalize_store), zip(writes, stores, strict=True)) - ) + list(starmap(delayed_close_after_writes, zip(writes, stores, strict=True))) ) +def get_writable_zarr_store( + store: ZarrStoreLike | None = None, + *, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + align_chunks: bool = False, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + zarr_format: int | None = None, + write_empty_chunks: bool | None = None, +) -> backends.ZarrStore: + """Create a store for writing to Zarr.""" + from 
xarray.backends.zarr import _choose_default_mode, _get_mappers + + kwargs, mapper, chunk_mapper = _get_mappers( + storage_options=storage_options, store=store, chunk_store=chunk_store + ) + mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region) + + if mode == "r+": + already_consolidated = consolidated + consolidate_on_close = False + else: + already_consolidated = False + consolidate_on_close = consolidated or consolidated is None + + return backends.ZarrStore.open_group( + store=mapper, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=already_consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_mapper, + append_dim=append_dim, + write_region=region, + safe_chunks=safe_chunks, + align_chunks=align_chunks, + zarr_version=zarr_version, + zarr_format=zarr_format, + write_empty=write_empty_chunks, + **kwargs, + ) + + # compute=True returns ZarrStore @overload def to_zarr( @@ -2270,7 +2354,6 @@ def to_zarr( See `Dataset.to_zarr` for full API docs. 
""" - from xarray.backends.zarr import _choose_default_mode, _get_mappers # validate Dataset keys, DataArray names _validate_dataset_names(dataset) @@ -2285,53 +2368,39 @@ def to_zarr( if encoding is None: encoding = {} - kwargs, mapper, chunk_mapper = _get_mappers( - storage_options=storage_options, store=store, chunk_store=chunk_store - ) - mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region) - - if mode == "r+": - already_consolidated = consolidated - consolidate_on_close = False - else: - already_consolidated = False - consolidate_on_close = consolidated or consolidated is None - - zstore = backends.ZarrStore.open_group( - store=mapper, + zstore = get_writable_zarr_store( + store, + chunk_store=chunk_store, mode=mode, synchronizer=synchronizer, group=group, - consolidated=already_consolidated, - consolidate_on_close=consolidate_on_close, - chunk_store=chunk_mapper, + consolidated=consolidated, append_dim=append_dim, - write_region=region, + region=region, safe_chunks=safe_chunks, align_chunks=align_chunks, + storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, - write_empty=write_empty_chunks, - **kwargs, + write_empty_chunks=write_empty_chunks, ) - dataset = zstore._validate_and_autodetect_region( - dataset, - ) + dataset = zstore._validate_and_autodetect_region(dataset) zstore._validate_encoding(encoding) writer = ArrayWriter() - # TODO: figure out how to properly handle unlimited_dims - dump_to_store(dataset, zstore, writer, encoding=encoding) - writes = writer.sync( - compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs - ) - if compute: - _finalize_store(writes, zstore) - else: - import dask + # TODO: figure out how to properly handle unlimited_dims + try: + dump_to_store(dataset, zstore, writer, encoding=encoding) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) + finally: + if compute: + zstore.close() - return 
dask.delayed(_finalize_store)(writes, zstore) + if not compute: + return delayed_close_after_writes(writes, zstore) return zstore diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 542ca4c897b..c4daef182f6 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -12,6 +12,7 @@ Any, ClassVar, Generic, + Self, TypeVar, Union, overload, @@ -305,6 +306,10 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None): class AbstractDataStore: __slots__ = () + def get_child_store(self, group: str) -> Self: # pragma: no cover + """Get a store corresponding to the indicated child group.""" + raise NotImplementedError() + def get_dimensions(self): # pragma: no cover raise NotImplementedError() @@ -581,6 +586,10 @@ def set_dimensions(self, variables, unlimited_dims=None): is_unlimited = dim in unlimited_dims self.set_dimension(dim, length, is_unlimited) + def sync(self): + """Write all buffered data to disk.""" + raise NotImplementedError() + def _infer_dtype(array, name=None): """Given an object array with no missing values, infer its dtype from all elements.""" diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 24a3324bf62..d842be967eb 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -4,7 +4,7 @@ import io import os from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -149,6 +149,17 @@ def __init__( self.lock = ensure_lock(lock) self.autoclose = autoclose + def get_child_store(self, group: str) -> Self: + if self._group is not None: + group = os.path.join(self._group, group) + return type(self)( + self._manager, + group=group, + mode=self._mode, + lock=self.lock, + autoclose=self.autoclose, + ) + @classmethod def open( cls, diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index c6a06dd714e..82d3e0b7dae 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py 
@@ -134,7 +134,7 @@ def _get_lock_maker(scheduler=None): raise KeyError(scheduler) -def _get_scheduler(get=None, collection=None) -> str | None: +def get_dask_scheduler(get=None, collection=None) -> str | None: """Determine the dask scheduler that is being used. None is returned if no dask scheduler is active. @@ -184,7 +184,7 @@ def get_write_lock(key): ------- Lock object that can be used like a threading.Lock object. """ - scheduler = _get_scheduler() + scheduler = get_dask_scheduler() lock_maker = _get_lock_maker(scheduler) return lock_maker(key) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index ab1841461f4..040c53a626b 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,7 +5,7 @@ import os from collections.abc import Iterable from contextlib import suppress -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -400,6 +400,17 @@ def __init__( self.lock = ensure_lock(lock) self.autoclose = autoclose + def get_child_store(self, group: str) -> Self: + if self._group is not None: + group = os.path.join(self._group, group) + return type(self)( + self._manager, + group=group, + mode=self._mode, + lock=self.lock, + autoclose=self.autoclose, + ) + @classmethod def open( cls, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1b62a87d10c..456c0182da1 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -5,7 +5,7 @@ import os import struct from collections.abc import Hashable, Iterable, Mapping -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Self, cast import numpy as np import pandas as pd @@ -735,6 +735,22 @@ def __init__( # on demand. 
self._members = self._fetch_members() + def get_child_store(self, group: str) -> Self: + zarr_group = self.zarr_group.require_group(group) + return type(self)( + zarr_group=zarr_group, + mode=self._mode, + consolidate_on_close=self._consolidate_on_close, + append_dim=self._append_dim, + write_region=self._write_region, + safe_chunks=self._safe_chunks, + write_empty=self._write_empty, + close_store_on_close=self._close_store_on_close, + use_zarr_fill_value_as_mask=self._use_zarr_fill_value_as_mask, + align_chunks=self._align_chunks, + cache_members=self._cache_members, + ) + @property def members(self) -> dict[str, ZarrArray | ZarrGroup]: """ @@ -996,9 +1012,6 @@ def store( kwargs["zarr_format"] = self.zarr_group.metadata.zarr_format zarr.consolidate_metadata(self.zarr_group.store, **kwargs) - def sync(self): - pass - def _open_existing_array(self, *, name) -> ZarrArray: import zarr from zarr import Array as ZarrArray @@ -1216,6 +1229,9 @@ def set_variables( writer.add(v.data, zarr_array, region) + def sync(self) -> None: + pass + def close(self) -> None: if self._close_store_on_close: self.zarr_group.store.close() diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bf82baccb31..6784ce179e2 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -18,6 +18,7 @@ TYPE_CHECKING, Any, Concatenate, + Literal, NoReturn, ParamSpec, TypeVar, @@ -75,7 +76,9 @@ if TYPE_CHECKING: import numpy as np import pandas as pd + from dask.delayed import Delayed + from xarray.backends import ZarrStore from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.types import ( Dims, @@ -1677,6 +1680,7 @@ def to_netcdf( **kwargs, ) -> memoryview: ... 
+ # compute=False returns dask.Delayed @overload def to_netcdf( self, @@ -1688,7 +1692,24 @@ def to_netcdf( engine: T_DataTreeNetcdfEngine | None = None, group: str | None = None, write_inherited_coords: bool = False, - compute: bool = True, + *, + compute: Literal[False], + **kwargs, + ) -> Delayed: ... + + # default return None + @overload + def to_netcdf( + self, + filepath: str | PathLike | io.IOBase, + mode: NetcdfWriteModes = "w", + encoding=None, + unlimited_dims=None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + write_inherited_coords: bool = False, + compute: Literal[True] = True, **kwargs, ) -> None: ... @@ -1704,7 +1725,7 @@ def to_netcdf( write_inherited_coords: bool = False, compute: bool = True, **kwargs, - ) -> None | memoryview: + ) -> None | memoryview | Delayed: """ Write datatree contents to a netCDF file. @@ -1748,13 +1769,13 @@ def to_netcdf( compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. - Currently, ``compute=False`` is not supported. kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` Returns ------- * ``memoryview`` if path is None + * ``dask.delayed.Delayed`` if compute is False * ``None`` otherwise Note @@ -1778,6 +1799,35 @@ def to_netcdf( **kwargs, ) + # compute=False returns dask.Delayed + @overload + def to_zarr( + self, + store, + mode: ZarrWriteModes = "w-", + encoding=None, + consolidated: bool = True, + group: str | None = None, + write_inherited_coords: bool = False, + *, + compute: Literal[False], + **kwargs, + ) -> Delayed: ... + + # default returns ZarrStore + @overload + def to_zarr( + self, + store, + mode: ZarrWriteModes = "w-", + encoding=None, + consolidated: bool = True, + group: str | None = None, + write_inherited_coords: bool = False, + compute: Literal[True] = True, + **kwargs, + ) -> ZarrStore: ... 
+ def to_zarr( self, store, @@ -1788,7 +1838,7 @@ def to_zarr( write_inherited_coords: bool = False, compute: bool = True, **kwargs, - ): + ) -> ZarrStore | Delayed: """ Write datatree contents to a Zarr store. @@ -1819,8 +1869,7 @@ def to_zarr( compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. Metadata - is always updated eagerly. Currently, ``compute=False`` is not - supported. + is always updated eagerly. kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` @@ -1831,7 +1880,7 @@ def to_zarr( """ from xarray.core.datatree_io import _datatree_to_zarr - _datatree_to_zarr( + return _datatree_to_zarr( self, store, mode=mode, diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index c586caaba89..052d3f3b45c 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -1,10 +1,18 @@ from __future__ import annotations import io -from collections.abc import Mapping +from collections.abc import Hashable, Mapping, MutableMapping from os import PathLike from typing import TYPE_CHECKING, Any, Literal, get_args +from xarray.backends.api import ( + delayed_close_after_writes, + dump_to_store, + get_writable_netcdf_store, + get_writable_zarr_store, +) +from xarray.backends.common import ArrayWriter +from xarray.backends.locks import get_dask_scheduler from xarray.core.datatree import DataTree from xarray.core.types import NetcdfWriteModes, ZarrWriteModes @@ -12,6 +20,9 @@ T_DataTreeNetcdfTypes = Literal["NETCDF4"] if TYPE_CHECKING: + from dask.delayed import Delayed + + from xarray.backends import ZarrStore from xarray.core.types import ZarrStoreLike @@ -26,7 +37,8 @@ def _datatree_to_netcdf( group: str | None = None, write_inherited_coords: bool = False, compute: bool = True, - **kwargs, + invalid_netcdf: bool = False, + auto_complex: bool | None = None, ) -> None | memoryview: """Implementation of `DataTree.to_netcdf`.""" @@ -46,9 
+58,6 @@ def _datatree_to_netcdf( "specifying a root group for the tree has not been implemented" ) - if not compute: - raise NotImplementedError("compute=False has not been implemented yet") - if encoding is None: encoding = {} @@ -70,27 +79,62 @@ def _datatree_to_netcdf( if unlimited_dims is None: unlimited_dims = {} - for node in dt.subtree: - at_root = node is dt - ds = node.to_dataset(inherit=write_inherited_coords or at_root) - group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_netcdf( - target, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - unlimited_dims=unlimited_dims.get(node.path), - engine=engine, - format=format, - compute=compute, - **kwargs, - ) - mode = "a" + scheduler = get_dask_scheduler() + have_chunks = any( + v.chunks is not None for node in dt.subtree for v in node.variables.values() + ) + autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] + + root_store = get_writable_netcdf_store( + target, + engine, # type: ignore[arg-type] + mode=mode, + format=format, + autoclose=autoclose, + invalid_netcdf=invalid_netcdf, + auto_complex=auto_complex, + ) + if group is not None: + root_store = root_store.get_child(group) + + writer = ArrayWriter() + + try: + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + + for node in dt.subtree: + at_root = node is dt + dataset = node.to_dataset(inherit=write_inherited_coords or at_root) + node_store = ( + root_store if at_root else root_store.get_child_store(node.path) + ) + dump_to_store( + dataset, + node_store, + writer, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + ) + + if autoclose: + root_store.close() + + writes = writer.sync(compute=compute) + + finally: + if compute: + root_store.close() + else: + root_store.sync() if filepath is None: assert isinstance(target, io.BytesIO) return target.getbuffer() + if not compute: + return 
delayed_close_after_writes(writes, root_store) + return None @@ -99,22 +143,31 @@ def _datatree_to_zarr( store: ZarrStoreLike, mode: ZarrWriteModes = "w-", encoding: Mapping[str, Any] | None = None, - consolidated: bool = True, + synchronizer=None, group: str | None = None, write_inherited_coords: bool = False, + *, + chunk_store: MutableMapping | str | PathLike | None = None, compute: bool = True, - **kwargs, -): + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + align_chunks: bool = False, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + zarr_format: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> ZarrStore | Delayed: """Implementation of `DataTree.to_zarr`.""" - from zarr import consolidate_metadata - if group is not None: raise NotImplementedError( "specifying a root group for the tree has not been implemented" ) - if "append_dim" in kwargs: + if append_dim is not None: raise NotImplementedError( "specifying ``append_dim`` with ``DataTree.to_zarr`` has not been implemented" ) @@ -130,21 +183,51 @@ def _datatree_to_zarr( f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" ) - for node in dt.subtree: - at_root = node is dt - ds = node.to_dataset(inherit=write_inherited_coords or at_root) - group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_zarr( - store, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - consolidated=False, - compute=compute, - **kwargs, + root_store = get_writable_zarr_store( + store, + chunk_store=chunk_store, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=consolidated, + append_dim=append_dim, + region=region, + safe_chunks=safe_chunks, + align_chunks=align_chunks, + storage_options=storage_options, + 
zarr_version=zarr_version, + zarr_format=zarr_format, + write_empty_chunks=write_empty_chunks, + ) + + writer = ArrayWriter() + + # TODO: figure out how to properly handle unlimited_dims + try: + for node in dt.subtree: + at_root = node is dt + dataset = node.to_dataset(inherit=write_inherited_coords or at_root) + node_store = ( + root_store if at_root else root_store.get_child_store(node.path) + ) + + dataset = node_store._validate_and_autodetect_region(dataset) + node_store._validate_encoding(encoding) + + dump_to_store( + dataset, + node_store, + writer, + encoding=encoding.get(node.path), + ) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs ) - if "w" in mode: - mode = "a" + finally: + if compute: + root_store.close() + + if not compute: + return delayed_close_after_writes(writes, root_store) - if consolidated: - consolidate_metadata(store) + return root_store diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index ec57993c4b2..45d65017172 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -265,6 +265,19 @@ def test_write_subgroup(self, tmpdir): assert_equal(original_dt, roundtrip_dt) assert_identical(expected_dt, roundtrip_dt) + def test_compute_false(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.nc" + original_dt = simple_datatree.chunk() + result = original_dt.to_netcdf(filepath, engine=self.engine, compute=False) + + with open_datatree(filepath, engine=self.engine) as in_progress_dt: + assert in_progress_dt.isomorphic(original_dt) + assert not in_progress_dt.equals(original_dt) + + result.compute() + with open_datatree(filepath, engine=self.engine) as written_dt: + assert_identical(written_dt, original_dt) + @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): @@ -575,6 +588,9 @@ def test_to_zarr(self, tmpdir, simple_datatree, zarr_format): with open_datatree(filepath, engine="zarr") as 
roundtrip_dt: assert_equal(original_dt, roundtrip_dt) + @pytest.mark.filterwarnings( + "ignore:Numcodecs codecs are not in the Zarr version 3 specification" + ) def test_zarr_encoding(self, tmpdir, simple_datatree, zarr_format): filepath = str(tmpdir / "test.zarr") original_dt = simple_datatree @@ -601,11 +617,10 @@ def test_zarr_encoding(self, tmpdir, simple_datatree, zarr_format): enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item] with pytest.raises(ValueError, match="unexpected encoding group.*"): - original_dt.to_zarr( - filepath, encoding=enc, engine="zarr", zarr_format=zarr_format - ) + original_dt.to_zarr(filepath, encoding=enc, zarr_format=zarr_format) @pytest.mark.xfail(reason="upstream zarr read-only changes have broken this test") + @pytest.mark.filterwarnings("ignore:Duplicate name") def test_to_zarr_zip_store(self, tmpdir, simple_datatree, zarr_format): from zarr.storage import ZipStore @@ -653,7 +668,9 @@ def test_to_zarr_compute_false( storepath = tmp_path / "test.zarr" original_dt = simple_datatree.chunk() - original_dt.to_zarr(str(storepath), compute=False, zarr_format=zarr_format) + result = original_dt.to_zarr( + str(storepath), compute=False, zarr_format=zarr_format + ) def assert_expected_zarr_files_exist( arr_dir: Path, @@ -724,6 +741,14 @@ def assert_expected_zarr_files_exist( zarr_format=zarr_format, ) + with open_datatree(str(storepath), engine="zarr") as in_progress_dt: + assert in_progress_dt.isomorphic(original_dt) + assert not in_progress_dt.equals(original_dt) + + result.compute() + with open_datatree(str(storepath), engine="zarr") as written_dt: + assert_identical(written_dt, original_dt) + def test_to_zarr_inherited_coords(self, tmpdir, zarr_format): original_dt = DataTree.from_dict( { From 22d3387b9ff789c369de9ba69a5ee85a42816618 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 13 Aug 2025 14:31:38 -0700 Subject: [PATCH 03/37] Fixes per review --- xarray/core/datatree_io.py | 2 -- 
xarray/tests/test_backends_datatree.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 052d3f3b45c..edf0cad3520 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -94,8 +94,6 @@ def _datatree_to_netcdf( invalid_netcdf=invalid_netcdf, auto_complex=auto_complex, ) - if group is not None: - root_store = root_store.get_child(group) writer = ArrayWriter() diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index f7f3ab1d29c..fdcfee1f1cd 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -282,6 +282,7 @@ def test_no_redundant_dimensions(self, tmpdir): assert list(root.dimensions) == ["x"] assert list(child.dimensions) == [] + @requires_dask def test_compute_false(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" original_dt = simple_datatree.chunk() From e68e18689e609f52ce481f64145b4d1bd7010bea Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 13 Aug 2025 15:34:35 -0700 Subject: [PATCH 04/37] Clean up comments --- xarray/core/datatree_io.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index edf0cad3520..eb863c60d3d 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -97,10 +97,9 @@ def _datatree_to_netcdf( writer = ArrayWriter() + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask try: - # TODO: allow this work (setting up the file for writing array data) - # to be parallelized with dask - for node in dt.subtree: at_root = node is dt dataset = node.to_dataset(inherit=write_inherited_coords or at_root) @@ -200,7 +199,6 @@ def _datatree_to_zarr( writer = ArrayWriter() - # TODO: figure out how to properly handle unlimited_dims try: for node in dt.subtree: at_root = node is dt From 
6d8ae1ecfa57dade14dcf264af79f7d4d958699d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 14 Aug 2025 21:45:13 -0700 Subject: [PATCH 05/37] Fix type for to_netcdf() --- xarray/core/datatree_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index eb863c60d3d..f4bd28af5eb 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -39,7 +39,7 @@ def _datatree_to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, -) -> None | memoryview: +) -> None | memoryview | Delayed: """Implementation of `DataTree.to_netcdf`.""" if format not in [None, *get_args(T_DataTreeNetcdfTypes)]: From d9da973dee24e1b40d12544cc966723dc8e55162 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 14 Aug 2025 22:36:36 -0700 Subject: [PATCH 06/37] Add test and whats-new for cross-group redundant computation --- doc/whats-new.rst | 10 +++++++--- xarray/tests/test_backends_datatree.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b348dcbd370..662ff472e94 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,9 @@ v2025.08.1 (unreleased) New Features ~~~~~~~~~~~~ +- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and + :py:meth:`DataTree.to_zarr`. + By `Stephan Hoyer `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -25,6 +28,10 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr` with avoid + redundant computation of Dask arrays with cross-group dependencies + (:issue:`10637`). + By `Stephan Hoyer `_. Documentation ~~~~~~~~~~~~~ @@ -54,9 +61,6 @@ New Features (:issue:`10326`, :pull:`10327`) By `Tom Nicholas `_. - :py:meth:`DataTree.to_netcdf` can now write to a file-like object, or return bytes if called without a filepath. (:issue:`10570`) By `Matthew Willson `_. 
-- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and - :py:meth:`DataTree.to_zarr`. - By `Stephan Hoyer `_. - Added exception handling for invalid files in :py:func:`open_mfdataset`. (:issue:`6736`) By `Pratiman Patel `_. diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 000cd9c02fa..2e65f2c4903 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -767,6 +767,31 @@ def assert_expected_zarr_files_exist( with open_datatree(str(storepath), engine="zarr") as written_dt: assert_identical(written_dt, original_dt) + @requires_dask + def test_to_zarr_no_redundant_computation(self, tmpdir, zarr_format): + import dask.array as da + + eval_count = 0 + + def expensive_func(x): + nonlocal eval_count + eval_count += 1 + return x + 1 + + base = da.random.random((), chunks=()) + derived1 = da.map_blocks(expensive_func, base, meta=np.array((), np.float64)) + derived2 = derived1 + 1 # depends on derived1 + tree = DataTree.from_dict( + { + "group1": xr.Dataset({"derived": derived1}), + "group2": xr.Dataset({"derived": derived2}), + } + ) + + filepath = str(tmpdir / "test.zarr") + tree.to_zarr(filepath, zarr_format=zarr_format) + assert eval_count == 1 # not 2 + def test_to_zarr_inherited_coords(self, tmpdir, zarr_format): original_dt = DataTree.from_dict( { From 205fdbe5db093bc74d1ec422f8acafcad49200cf Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 15 Aug 2025 17:19:38 -0700 Subject: [PATCH 07/37] Fix test failure on CI (and add a better test) --- xarray/backends/api.py | 9 --------- xarray/tests/test_backends.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c1516e2e9ad..15064e10502 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -2153,15 +2153,6 @@ def dump_to_store( if encoding is None: encoding = {} - if unlimited_dims is None: - unlimited_dims = 
dataset.encoding.get("unlimited_dims", None) - - if unlimited_dims is not None: - if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): - unlimited_dims = [unlimited_dims] - else: - unlimited_dims = list(unlimited_dims) - variables, attrs = conventions.encode_dataset_coordinates(dataset) check_encoding = set() diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index c336fe7bd0d..fec70817871 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2634,6 +2634,12 @@ def test_manual_chunk(self) -> None: assert_identical(actual, auto) assert_identical(actual.load(), auto.load()) + def test_unlimited_dims_encoding_is_ignored(self) -> None: + ds = Dataset({"x": np.arange(10)}) + ds.encoding = {"unlimited_dims": ["x"]} + with self.roundtrip(ds) as actual: + assert_identical(ds, actual) + @requires_dask @pytest.mark.filterwarnings("ignore:.*does not have a Zarr V3 specification.*") def test_warning_on_bad_chunks(self) -> None: @@ -6820,8 +6826,12 @@ def test_extract_zarr_variable_encoding() -> None: def test_open_fsspec() -> None: import fsspec - if not hasattr(zarr.storage, "FSStore") or not hasattr( - zarr.storage.FSStore, "getitems" + if not ( + ( + hasattr(zarr.storage, "FSStore") + and hasattr(zarr.storage.FSStore, "getitems") + ) # zarr v2 + or hasattr(zarr.storage, "FsspecStore") # zarr v3 ): pytest.skip("zarr too old") From e82c33401d31b7c514c2ec07fcff8140fb897d40 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 15 Aug 2025 18:10:18 -0700 Subject: [PATCH 08/37] grammar --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 662ff472e94..51e4f6c6658 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,7 +28,7 @@ Deprecations Bug fixes ~~~~~~~~~ -- :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr` with avoid +- :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr` now avoid 
redundant computation of Dask arrays with cross-group dependencies (:issue:`10637`). By `Stephan Hoyer `_. From ca5feca47a574bb8f8184f8155e005056c34163d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 16 Aug 2025 12:23:33 -0700 Subject: [PATCH 09/37] Tweaks --- xarray/backends/netCDF4_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 9889d0e2982..81cc2e9f2d8 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -367,6 +367,8 @@ def _build_and_get_enum( @dataclass class _Thunk: + """Pickleable equivalent of `lambda: value`.""" + value: Any def __call__(self): From cf2218824c30c9b388429cea45bf3d17d94482f9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 19 Aug 2025 10:43:12 -0700 Subject: [PATCH 10/37] post merge fixes --- xarray/backends/api.py | 108 +++++++++++------------------------------ 1 file changed, 27 insertions(+), 81 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b255146ce65..b31c73fa229 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, Any, - Final, Literal, TypeVar, Union, @@ -98,60 +97,31 @@ DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" -ENGINES = { - "netcdf4": backends.NetCDF4DataStore.open, - "scipy": backends.ScipyDataStore, - "pydap": backends.PydapDataStore.open, - "h5netcdf": backends.H5NetCDFStore.open, - "zarr": backends.ZarrStore.open_group, -} - - -def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: - engine: Literal["netcdf4", "pydap"] - try: - import netCDF4 # noqa: F401 - - engine = "netcdf4" - except ImportError: # pragma: no cover - try: - import pydap # noqa: F401 - - engine = "pydap" - except ImportError as err: - raise ValueError( - "netCDF4 or pydap is required for accessing remote datasets via OPeNDAP" - ) from err - return engine - - -def 
_get_default_engine_gz() -> Literal["scipy"]: - try: - import scipy # noqa: F401 - - engine: Final = "scipy" - except ImportError as err: # pragma: no cover - raise ValueError("scipy is required for accessing .gz files") from err - return engine - -def get_default_engine_netcdf( +def get_default_netcdf_write_engine( format: T_NetcdfTypes | None, + to_file_object: bool = False, + to_memoryview: bool = False, ) -> Literal["netcdf4", "h5netcdf", "scipy"]: + """Return the default netCDF library to use for writing a netCDF file.""" engines = { "netcdf4": "netCDF4", - "scipy": "scipy.io.netcdf", + "scipy": "scipy.io", "h5netcdf": "h5netcdf", } - if format is None: - candidates = ["netcdf4", "h5netcdf", "scipy"] - elif format.upper().startswith("NETCDF3"): - candidates = ["netcdf4", "scipy"] - elif format.upper().startswith("NETCDF4"): - candidates = ["netcdf4", "h5netcdf"] - else: - raise AssertionError(f"unexpected {format=}") + candidates = list(plugins.STANDARD_BACKENDS_ORDER) + + if format is not None: + if format.upper().startswith("NETCDF3"): + candidates.remove("h5netcdf") + elif format.upper().startswith("NETCDF4"): + candidates.remove("scipy") + else: + raise ValueError(f"unexpected {format=}") + + if to_file_object: + candidates.remove("netcdf4") for engine in candidates: module_name = engines[engine] @@ -159,25 +129,13 @@ def get_default_engine_netcdf( return cast(Literal["netcdf4", "h5netcdf", "scipy"], engine) format_str = f"with {format=}" if format is not None else "" + libraries = ", ".join(engines[c] for c in candidates) raise ValueError( - f"cannot read or write NetCDF files{format_str} because none of " - f"{set(candidates)} are installed" + f"cannot write NetCDF files{format_str} because none of the suitable " + f"backend libraries ({libraries}) are installed" ) -def _get_default_engine( - path: str | None, - allow_remote: bool = False, - format: T_NetcdfTypes | None = None, -) -> T_NetcdfEngine: - if path is not None: - if allow_remote and 
is_remote_uri(path): - return _get_default_engine_remote_uri() # type: ignore[return-value] - if path.endswith(".gz"): - return _get_default_engine_gz() - return get_default_engine_netcdf(format) - - def _validate_dataset_names(dataset: Dataset) -> None: """DataArray.name and Dataset keys must be a string or None""" @@ -371,7 +329,7 @@ def load_dataset(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> Dataset: return ds.load() -def load_dataarray(filename_or_obj: T_PathFileOrDataStore, **kwargs): +def load_dataarray(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> DataArray: """Open, load into memory, and close a DataArray from a file or file-like object containing a single data variable. @@ -2083,27 +2041,15 @@ def to_netcdf( The ``multifile`` argument is only for the private use of save_mfdataset. """ - if isinstance(path_or_file, os.PathLike): - path_or_file = os.fspath(path_or_file) - if encoding is None: encoding = {} - if isinstance(path_or_file, str) or path_or_file is None: - if engine is None: - engine = _get_default_engine(path_or_file, format=format) - path_or_file = _normalize_path(path_or_file) - # writing to a file-like object - elif engine is None: - # TODO: only use 'scipy' if format is None or a netCDF3 format - engine = "scipy" - elif engine not in ("scipy", "h5netcdf"): - raise ValueError( - "invalid engine for creating bytes/memoryview or writing to a " - f"file-like object with to_netcdf: {engine!r}. Only " - "engine=None, engine='scipy' and engine='h5netcdf' is " - "supported." 
- ) + path_or_file = _normalize_path(path_or_file) + + if engine is None: + to_memoryview = path_or_file is None + to_file_object = not to_memoryview and not isinstance(path_or_file, str) + engine = get_default_netcdf_write_engine(format, to_file_object, to_memoryview) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) From dc8bf9fbae470856a3bb46d93a3b864bbd4852d2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 5 Sep 2025 18:55:31 -0700 Subject: [PATCH 11/37] Fix tests --- xarray/backends/api.py | 17 +- xarray/backends/file_manager.py | 27 +++- xarray/backends/h5netcdf_.py | 13 +- xarray/backends/netCDF4_.py | 13 +- xarray/core/datatree_io.py | 4 +- xarray/tests/test_backends.py | 46 ++---- xarray/tests/test_backends_api.py | 14 +- xarray/tests/test_backends_datatree.py | 208 ++++++++++++------------- 8 files changed, 165 insertions(+), 177 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 8302f8cd821..aa6fddd89e1 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -100,7 +100,7 @@ def get_default_netcdf_write_engine( format: T_NetcdfTypes | None, - to_fileobject_or_memoryview: bool, + to_fileobject: bool, ) -> Literal["netcdf4", "h5netcdf", "scipy"]: """Return the default netCDF library to use for writing a netCDF file.""" module_names = { @@ -119,7 +119,7 @@ def get_default_netcdf_write_engine( else: raise ValueError(f"unexpected {format=}") - if to_fileobject_or_memoryview: + if to_fileobject: candidates.remove("netcdf4") for engine in candidates: @@ -2046,8 +2046,8 @@ def to_netcdf( path_or_file = _normalize_path(path_or_file) if engine is None: - to_fileobject_or_memoryview = not isinstance(path_or_file, str) - engine = get_default_netcdf_write_engine(format, to_fileobject_or_memoryview) + to_fileobject = isinstance(path_or_file, IOBase) + engine = get_default_netcdf_write_engine(format, to_fileobject) # validate Dataset keys, DataArray names, and attr 
keys/values _validate_dataset_names(dataset) @@ -2133,15 +2133,6 @@ def dump_to_store( if encoding is None: encoding = {} - if unlimited_dims is None: - unlimited_dims = dataset.encoding.get("unlimited_dims", None) - - if unlimited_dims is not None: - if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): - unlimited_dims = [unlimited_dims] - else: - unlimited_dims = list(unlimited_dims) - variables, attrs = conventions.encode_dataset_coordinates(dataset) check_encoding = set() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index bc8c199a913..e21df606811 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -329,6 +329,10 @@ def __hash__(self): return self.hashvalue +def _get_none(*args, **kwargs) -> None: + return None + + class PickleableFileManager(FileManager): """File manager that supports pickling by reopening a file object. @@ -357,23 +361,33 @@ def __init__( kwargs = kwargs.copy() kwargs["mode"] = mode self._file = opener(*args, **kwargs) - self._closed = False + + @property + def _closed(self) -> bool: + # If opener() raised an error in the constructor, _file may not be set + return getattr(self, "_file", None) is None + + def _assert_open(self): + if self._closed: + raise ValueError("file is closed") def acquire(self, needs_lock=True): + self._assert_open() return self._file @contextlib.contextmanager def acquire_context(self, needs_lock=True): + self._assert_open() yield self._file def close(self, needs_lock=True): if not self._closed: + assert self._file is not None self._file.close() - self._closed = True + self._file = None def __del__(self) -> None: - # If opener() raised an error in the constructor, _closed may not be set - if not getattr(self, "_closed", True): + if not self._closed: self.close() if OPTIONS["warn_for_unclosed_files"]: @@ -386,13 +400,16 @@ def __del__(self) -> None: def __getstate__(self): # file is intentionally omitted: we want to open it 
again - return (self._opener, self._args, self._mode, self._kwargs) + opener = _get_none if self._closed else self._opener + return (opener, self._args, self._mode, self._kwargs) def __setstate__(self, state) -> None: opener, args, mode, kwargs = state self.__init__(opener, *args, mode=mode, kwargs=kwargs) # type: ignore[misc] def __repr__(self) -> str: + if self._closed: + return f"" args_string = ", ".join(map(repr, self._args)) if self._mode is not _DEFAULT_MODE: args_string += f", mode={self._mode!r}" diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index fb420b9e2f2..077cd2ce382 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -189,7 +189,7 @@ def open( if isinstance(filename, BytesIOProxy): source = filename filename = io.BytesIO() - source.getter = filename.getbuffer + source.getvalue = filename.getbuffer if isinstance(filename, io.IOBase) and mode == "r": magic_number = read_magic_number_from_file(filename) @@ -218,7 +218,9 @@ def open( lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) manager_cls = ( - CachingFileManager if isinstance(filename, str) else PickleableFileManager + CachingFileManager + if isinstance(filename, str) and not is_remote_uri(filename) + else PickleableFileManager ) manager = manager_cls(h5netcdf.File, filename, mode=mode, kwargs=kwargs) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) @@ -460,6 +462,11 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: filename_or_obj = _normalize_filename_or_obj(filename_or_obj) + # magic_number = ( + # bytes(filename_or_obj[:8]) + # if isinstance(filename_or_obj, bytes | memoryview) + # else try_read_magic_number_from_path(filename_or_obj) + # ) magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None: return magic_number.startswith(b"\211HDF\r\n\032\n") @@ -647,7 +654,7 @@ def 
open_groups_as_dict( # only warn if phony_dims exist in file # remove together with the above check # after some versions - if store.ds._phony_dim_count > 0 and emit_phony_dims_warning: + if store.ds._root._phony_dim_count > 0 and emit_phony_dims_warning: _emit_phony_dims_warning() return groups_dict diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 4c7ddb87fc0..489a5752ac8 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -384,7 +384,7 @@ class _CloseWithCopy: def __call__(self): value = self.nc4_dataset.close() - self.proxy.getter = _Thunk(value) + self.proxy.getvalue = _Thunk(value) class NetCDF4DataStore(WritableCFDataStore): @@ -462,11 +462,7 @@ def open( filename = os.fspath(filename) if not isinstance(filename, str | bytes | memoryview | BytesIOProxy): - raise TypeError( - f"invalid filename for netCDF4 backend: {filename}" - # "can only read bytes or file-like objects " - # "with engine='scipy' or 'h5netcdf'" - ) + raise TypeError(f"invalid filename for netCDF4 backend: {filename}") if format is None: format = "NETCDF4" @@ -515,6 +511,11 @@ def open( manager = PickleableFileManager( netCDF4.Dataset, "", mode=mode, kwargs=kwargs ) + # nc4_dataset = netCDF4.Dataset("", mode=mode, **kwargs) + # def close(): + # if nc4_dataset.isopen(): + # nc4_dataset.close() + # manager = DummyFileManager(nc4_dataset, close=close) else: manager = CachingFileManager( netCDF4.Dataset, filename, mode=mode, kwargs=kwargs diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 9151b90d015..06d6f4278da 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -55,10 +55,10 @@ def _datatree_to_netcdf( filepath = _normalize_path(filepath) if engine is None: - to_fileobject_or_memoryview = not isinstance(filepath, str) + to_fileobject = isinstance(filepath, io.IOBase) engine = get_default_netcdf_write_engine( format="NETCDF4", # required for supporting groups - 
to_fileobject_or_memoryview=to_fileobject_or_memoryview, + to_fileobject=to_fileobject, ) # type: ignore[assignment] if group is not None: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 23a1758e23d..2562f09fe2c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2523,6 +2523,14 @@ def test_pickle_open_dataset_from_bytes(self) -> None: assert_identical(unpickled, original) unpickled.close() + def test_compute_false(self) -> None: + original = create_test_data() + with pytest.raises( + NotImplementedError, + match=re.escape("to_netcdf() with compute=False is not yet implemented"), + ): + original.to_netcdf(engine=self.engine, compute=False) + class InMemoryNetCDFWithGroups(InMemoryNetCDF): def test_roundtrip_group_via_memoryview(self) -> None: @@ -4595,30 +4603,6 @@ def roundtrip( async def test_load_async(self) -> None: await super().test_load_async() - def test_to_netcdf_explicit_engine(self) -> None: - Dataset({"foo": 42}).to_netcdf(engine="scipy") - - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(engine="scipy") - roundtrip = open_dataset(netcdf_bytes, engine="scipy") - assert_identical(roundtrip, original) - - def test_to_bytes_compute_false(self) -> None: - original = create_test_data() - with pytest.raises( - NotImplementedError, - match=re.escape("to_netcdf() with compute=False is not yet implemented"), - ): - original.to_netcdf(engine="scipy", compute=False) - - def test_bytes_pickle(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - fobj = data.to_netcdf(engine="scipy") - with self.open(fobj) as ds: - unpickled = pickle.loads(pickle.dumps(ds)) - assert_identical(unpickled, data) - @requires_scipy class TestScipyFileObject(CFEncodedBase, NetCDF3Only, FileObjectNetCDF): @@ -4646,6 +4630,11 @@ def roundtrip( with self.open(f, **open_kwargs) as ds: yield ds + @pytest.mark.asyncio + @pytest.mark.skip(reason="NetCDF backends 
don't support async loading") + async def test_load_async(self) -> None: + await super().test_load_async() + @pytest.mark.xfail(reason="not working yet") def test_open_twice(self): super().test_open_twice() @@ -4775,14 +4764,6 @@ def test_engine(self) -> None: with pytest.raises(ValueError, match=r"unrecognized engine"): data.to_netcdf("foo.nc", engine="foobar") # type: ignore[call-overload] - with pytest.raises( - ValueError, - match=re.escape( - "can only read bytes or file-like objects with engine='scipy' or 'h5netcdf'" - ), - ): - data.to_netcdf(engine="netcdf4") - with create_tmp_file() as tmp_file: data.to_netcdf(tmp_file) with pytest.raises(ValueError, match=r"unrecognized engine"): @@ -5224,6 +5205,7 @@ def test_write_inconsistent_chunks(self) -> None: assert actual["y"].encoding["chunksizes"] == (100, 50) +@network @requires_h5netcdf_ros3 class TestH5NetCDFDataRos3Driver(TestCommon): engine: T_NetcdfEngine = "h5netcdf" diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index 2d659dcb9c9..2d0475ef31a 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -23,23 +23,17 @@ @requires_scipy @requires_h5netcdf def test_get_default_netcdf_write_engine() -> None: - engine = get_default_netcdf_write_engine( - format=None, to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format=None, to_fileobject=False) assert engine == "netcdf4" - engine = get_default_netcdf_write_engine( - format="NETCDF4", to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=False) assert engine == "netcdf4" - engine = get_default_netcdf_write_engine( - format="NETCDF4", to_fileobject_or_memoryview=True - ) + engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=True) assert engine == "h5netcdf" engine = get_default_netcdf_write_engine( - format="NETCDF3_CLASSIC", to_fileobject_or_memoryview=True + 
format="NETCDF3_CLASSIC", to_fileobject=True ) assert engine == "scipy" diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 100b05fbcc7..7ec329900d4 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -321,42 +321,12 @@ def test_default_write_engine(self, tmpdir, simple_datatree, monkeypatch): original_dt = simple_datatree original_dt.to_netcdf(filepath) # should not raise - def test_roundtrip_via_memoryview_engine_specified(self, simple_datatree): - original_dt = simple_datatree - roundtrip_dt = load_datatree( - original_dt.to_netcdf(engine=self.engine), engine=self.engine - ) - assert_equal(original_dt, roundtrip_dt) - - -@requires_h5netcdf_or_netCDF4 -class TestGenericNetCDFIO: - def test_roundtrip_via_memoryview(self, simple_datatree): - original_dt = simple_datatree - roundtrip_dt = load_datatree(original_dt.to_netcdf()) - assert_equal(original_dt, roundtrip_dt) - -@requires_netCDF4 -class TestNetCDF4DatatreeIO(DatatreeIOBase): - engine: T_DataTreeNetcdfEngine | None = "netcdf4" - - def test_open_datatree(self, unaligned_datatree_nc) -> None: - """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - - with pytest.raises( - ValueError, - match=( - re.escape( - "group '/Group1/subgroup1' is not aligned with its parents:\nGroup:\n" - ) - + ".*" - ), - ): - open_datatree(unaligned_datatree_nc) +class NetCDFIOBase(DatatreeIOBase): + engine: T_DataTreeNetcdfEngine | None @requires_dask - def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: + def test_open_datatree_chunks(self, tmpdir) -> None: filepath = tmpdir / "test.nc" chunks = {"x": 2, "y": 1} @@ -371,13 +341,65 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: "/group2": set2_data.chunk(chunks), } ) - original_tree.to_netcdf(filepath, engine="netcdf4") + original_tree.to_netcdf(filepath, engine=self.engine) - with open_datatree(filepath, 
engine="netcdf4", chunks=chunks) as tree: + with open_datatree(filepath, engine=self.engine, chunks=chunks) as tree: xr.testing.assert_identical(tree, original_tree) assert_chunks_equal(tree, original_tree, enforce_dask=True) + # def test_roundtrip_via_memoryview(self, simple_datatree) -> None: + # original_dt = simple_datatree + # memview = original_dt.to_netcdf(engine=self.engine) + # roundtrip_dt = load_datatree(memview, engine=self.engine) + # assert_equal(original_dt, roundtrip_dt) + + def test_to_bytes_compute_false(self, simple_datatree) -> None: + original_dt = simple_datatree + with pytest.raises( + NotImplementedError, + match=re.escape("to_netcdf() with compute=False is not yet implemented"), + ): + original_dt.to_netcdf(engine=self.engine, compute=False) + + def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: + """Test opening a specific group within a NetCDF file using `open_datatree`.""" + filepath = tmpdir / "test.nc" + group = "/set1" + original_dt = simple_datatree + original_dt.to_netcdf(filepath, engine=self.engine) + expected_subtree = original_dt[group].copy() + expected_subtree.orphan() + with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree: + assert subgroup_tree.root.parent is None + assert_equal(subgroup_tree, expected_subtree) + + +@requires_h5netcdf_or_netCDF4 +class TestGenericNetCDFIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = None + + # def test_cross_engine_roundtrip_via_memoryview(self, simple_datatree) -> None: + # original_dt = simple_datatree + # memview = original_dt.to_netcdf(engine='h5netcdf') + # roundtrip_dt = load_datatree(memview, engine='') + # # del memview + # assert_equal(original_dt, roundtrip_dt) + + def test_open_datatree(self, unaligned_datatree_nc) -> None: + """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" + + with pytest.raises( + ValueError, + match=( + re.escape( + "group '/Group1/subgroup1' is not aligned 
with its parents:\nGroup:\n" + ) + + ".*" + ), + ): + open_datatree(unaligned_datatree_nc) + def test_open_groups(self, unaligned_datatree_nc) -> None: """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy.""" unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc) @@ -420,7 +442,7 @@ def test_open_groups_chunks(self, tmpdir) -> None: ) original_tree.to_netcdf(filepath, mode="w") - dict_of_datasets = open_groups(filepath, engine="netcdf4", chunks=chunks) + dict_of_datasets = open_groups(filepath, chunks=chunks) for path, ds in dict_of_datasets.items(): assert {k: max(vs) for k, vs in ds.chunksizes.items()} == chunks, ( @@ -430,6 +452,11 @@ def test_open_groups_chunks(self, tmpdir) -> None: for ds in dict_of_datasets.values(): ds.close() + +@requires_netCDF4 +class TestNetCDF4DatatreeIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = "netcdf4" + def test_open_groups_to_dict(self, tmpdir) -> None: """Create an aligned netCDF4 with the following structure to test `open_groups` and `DataTree.from_dict`. 
@@ -477,17 +504,45 @@ def test_open_groups_to_dict(self, tmpdir) -> None: for ds in aligned_dict_of_datasets.values(): ds.close() - def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: - """Test opening a specific group within a NetCDF file using `open_datatree`.""" - filepath = tmpdir / "test.nc" - group = "/set1" + +@requires_h5netcdf +class TestH5NetCDFDatatreeIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = "h5netcdf" + + def test_phony_dims_warning(self, tmpdir) -> None: + filepath = tmpdir + "/phony_dims.nc" + import h5py + + foo_data = np.arange(125).reshape(5, 5, 5) + bar_data = np.arange(625).reshape(25, 5, 5) + var = {"foo1": foo_data, "foo2": bar_data, "foo3": foo_data, "foo4": bar_data} + with h5py.File(filepath, "w") as f: + grps = ["bar", "baz"] + for grp in grps: + fx = f.create_group(grp) + for k, v in var.items(): + fx.create_dataset(k, data=v) + + with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"): + with open_datatree(filepath, engine=self.engine) as tree: + assert tree.bar.dims == { + "phony_dim_0": 5, + "phony_dim_1": 5, + "phony_dim_2": 5, + "phony_dim_3": 25, + } + + def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree) -> None: original_dt = simple_datatree - original_dt.to_netcdf(filepath) - expected_subtree = original_dt[group].copy() - expected_subtree.orphan() - with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree: - assert subgroup_tree.root.parent is None - assert_equal(subgroup_tree, expected_subtree) + filepath = tmpdir + "/test.nc" + # h5py requires both read and write access when writing, it will + # work with file-like objects provided they support both, and are + # seekable. 
+ with open(filepath, "wb+") as file: + original_dt.to_netcdf(file, engine=self.engine) + with open(filepath, "rb") as file: + with open_datatree(file, engine=self.engine) as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) @network @@ -583,64 +638,6 @@ def test_open_groups_to_dict(self, url=all_aligned_child_nodes_url) -> None: assert opened_tree.identical(aligned_dt) -@requires_h5netcdf -class TestH5NetCDFDatatreeIO(DatatreeIOBase): - engine: T_DataTreeNetcdfEngine | None = "h5netcdf" - - def test_phony_dims_warning(self, tmpdir) -> None: - filepath = tmpdir + "/phony_dims.nc" - import h5py - - foo_data = np.arange(125).reshape(5, 5, 5) - bar_data = np.arange(625).reshape(25, 5, 5) - var = {"foo1": foo_data, "foo2": bar_data, "foo3": foo_data, "foo4": bar_data} - with h5py.File(filepath, "w") as f: - grps = ["bar", "baz"] - for grp in grps: - fx = f.create_group(grp) - for k, v in var.items(): - fx.create_dataset(k, data=v) - - with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"): - with open_datatree(filepath, engine=self.engine) as tree: - assert tree.bar.dims == { - "phony_dim_0": 5, - "phony_dim_1": 5, - "phony_dim_2": 5, - "phony_dim_3": 25, - } - - def test_roundtrip_via_bytes(self, simple_datatree) -> None: - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf()) - assert_equal(original_dt, roundtrip_dt) - - def test_roundtrip_via_bytes_engine_specified(self, simple_datatree) -> None: - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf(engine=self.engine)) - assert_equal(original_dt, roundtrip_dt) - - def test_to_bytes_compute_false(self, simple_datatree) -> None: - original_dt = simple_datatree - with pytest.raises( - NotImplementedError, - match=re.escape("to_netcdf() with compute=False is not yet implemented"), - ): - original_dt.to_netcdf(compute=False) - - def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree) -> None: - original_dt = simple_datatree - 
filepath = tmpdir + "/test.nc" - # h5py requires both read and write access when writing, it will - # work with file-like objects provided they support both, and are - # seekable. - with open(filepath, "wb+") as file: - original_dt.to_netcdf(file, engine=self.engine) - with open(filepath, "rb") as file: - with open_datatree(file, engine=self.engine) as roundtrip_dt: - assert_equal(original_dt, roundtrip_dt) - - @requires_zarr @parametrize_zarr_format class TestZarrDatatreeIO: @@ -814,11 +811,10 @@ def assert_expected_zarr_files_exist( in_progress_dt = load_datatree(str(storepath), engine="zarr") assert not in_progress_dt.equals(original_dt) - result.compute() # type: ignore[union-attr] + result.compute() written_dt = load_datatree(str(storepath), engine="zarr") assert_identical(written_dt, original_dt) - @requires_dask @requires_dask def test_rplus_mode( self, tmp_path: Path, simple_datatree: DataTree, zarr_format: Literal[2, 3] From d6b32cbb9a94db7ce5bb6c10315a37a09172d846 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 5 Sep 2025 19:06:49 -0700 Subject: [PATCH 12/37] More tests --- xarray/backends/netCDF4_.py | 5 ----- xarray/tests/test_backends.py | 24 +++++++++++++++++------ xarray/tests/test_backends_datatree.py | 27 ++++++++++++++++++++------ 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 489a5752ac8..b0eb8216c25 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -511,11 +511,6 @@ def open( manager = PickleableFileManager( netCDF4.Dataset, "", mode=mode, kwargs=kwargs ) - # nc4_dataset = netCDF4.Dataset("", mode=mode, **kwargs) - # def close(): - # if nc4_dataset.isopen(): - # nc4_dataset.close() - # manager = DummyFileManager(nc4_dataset, close=close) else: manager = CachingFileManager( netCDF4.Dataset, filename, mode=mode, kwargs=kwargs diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 2562f09fe2c..48343fe61ea 
100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2503,12 +2503,6 @@ def test_roundtrip(self) -> None: roundtrip = load_dataset(result, engine=self.engine) assert_identical(roundtrip, original) - def test_roundtrip_via_memoryview(self) -> None: - original = create_test_data() - result = memoryview(original.to_netcdf(engine=self.engine)) - roundtrip = load_dataset(result, engine=self.engine) - assert_identical(roundtrip, original) - def test_roundtrip_via_bytes(self) -> None: original = create_test_data() result = bytes(original.to_netcdf(engine=self.engine)) @@ -5205,6 +5199,24 @@ def test_write_inconsistent_chunks(self) -> None: assert actual["y"].encoding["chunksizes"] == (100, 50) +@requires_netCDF4 +@requires_h5netcdf +def test_memoryview_write_h5netcdf_read_netcdf4() -> None: + original = create_test_data() + result = original.to_netcdf(engine="h5netcdf") + roundtrip = load_dataset(result, engine="netcdf4") + assert_identical(roundtrip, original) + + +@requires_netCDF4 +@requires_h5netcdf +def test_memoryview_write_netcdf4_read_h5netcdf() -> None: + original = create_test_data() + result = original.to_netcdf(engine="netcdf4") + roundtrip = load_dataset(result, engine="h5netcdf") + assert_identical(roundtrip, original) + + @network @requires_h5netcdf_ros3 class TestH5NetCDFDataRos3Driver(TestCommon): diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 7ec329900d4..72e7e3ee77e 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -379,12 +379,27 @@ def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: class TestGenericNetCDFIO(NetCDFIOBase): engine: T_DataTreeNetcdfEngine | None = None - # def test_cross_engine_roundtrip_via_memoryview(self, simple_datatree) -> None: - # original_dt = simple_datatree - # memview = original_dt.to_netcdf(engine='h5netcdf') - # roundtrip_dt = load_datatree(memview, engine='') - # # 
del memview - # assert_equal(original_dt, roundtrip_dt) + @requires_h5netcdf + @requires_netCDF4 + def test_memoryview_write_h5netcdf_read_netcdf4(self, simple_datatree) -> None: + # This test triggers a warning from pytype, but does not fail: + # PytestUnraisableExceptionWarning: Exception ignored in: <_io.BytesIO object at 0x32ce0c540> + # BufferError: Existing exports of data: object cannot be re-sized + # The warning is silenced if _either_ memview is converted into bytes or + # the use of PickleableFileManager inside NetCDF4DataStore.open() is + # replaced by DummyFileStore. shoyer suspects a netCDF4-python bug. + original_dt = simple_datatree + memview = original_dt.to_netcdf(engine="h5netcdf") + roundtrip_dt = load_datatree(memview, engine="netcdf4") + assert_equal(original_dt, roundtrip_dt) + + @requires_h5netcdf + @requires_netCDF4 + def test_memoryview_write_netcdf4_read_h5netcdf(self, simple_datatree) -> None: + original_dt = simple_datatree + memview = original_dt.to_netcdf(engine="netcdf4") + roundtrip_dt = load_datatree(memview, engine="h5netcdf") + assert_equal(original_dt, roundtrip_dt) def test_open_datatree(self, unaligned_datatree_nc) -> None: """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" From 0311987367d79cc56484ff29f0083a478a87dfd3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 7 Sep 2025 20:22:31 -0700 Subject: [PATCH 13/37] Fix bug and add type annotations --- xarray/backends/file_manager.py | 150 ++++++++++++--------- xarray/backends/locks.py | 34 ++--- xarray/core/types.py | 11 ++ xarray/tests/test_backends_api.py | 11 +- xarray/tests/test_backends_datatree.py | 6 - xarray/tests/test_backends_file_manager.py | 11 ++ 6 files changed, 126 insertions(+), 97 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index e21df606811..d269b1bfb45 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,31 +1,33 @@ from __future__ 
import annotations import atexit -import contextlib -import io import threading import uuid import warnings -from collections.abc import Hashable, Mapping -from typing import Any +from collections.abc import Callable, Hashable, Iterator, Mapping, MutableMapping +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Generic, Literal, TypeVar, cast from xarray.backends.locks import acquire from xarray.backends.lru_cache import LRUCache from xarray.core import utils from xarray.core.options import OPTIONS +from xarray.core.types import Closable, Lock # Global cache for storing open files. -FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache( +FILE_CACHE: LRUCache[Any, Closable] = LRUCache( maxsize=OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close() ) assert FILE_CACHE.maxsize, "file cache must be at least size one" +T_File = TypeVar("T_File", bound=Closable) + REF_COUNTS: dict[Any, int] = {} -_DEFAULT_MODE = utils.ReprObject("") +_OMIT_MODE = utils.ReprObject("") -class FileManager: +class FileManager(Generic[T_File]): """Manager for acquiring and closing a file object. Use FileManager subclasses (CachingFileManager in particular) on backend @@ -33,11 +35,13 @@ class FileManager: many open files and transferring them between multiple processes. """ - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_File: """Acquire the file object from this manager.""" raise NotImplementedError() - def acquire_context(self, needs_lock=True): + def acquire_context( + self, needs_lock: bool = True + ) -> AbstractContextManager[T_File]: """Context manager for acquiring a file. Yields a file object. 
The context manager unwinds any actions taken as part of acquisition @@ -46,12 +50,12 @@ def acquire_context(self, needs_lock=True): """ raise NotImplementedError() - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: """Close the file object associated with this manager, if needed.""" raise NotImplementedError() -class CachingFileManager(FileManager): +class CachingFileManager(FileManager[T_File]): """Wrapper for automatically opening and closing file objects. Unlike files, CachingFileManager objects can be safely pickled and passed @@ -81,14 +85,14 @@ class CachingFileManager(FileManager): def __init__( self, - opener, + opener: Callable[..., T_File], *args: Any, - mode: Any = _DEFAULT_MODE, + mode: Any = _OMIT_MODE, kwargs: Mapping[str, Any] | None = None, - lock=None, - cache=None, + lock: Lock | None | Literal[False] = None, + cache: MutableMapping[Any, T_File] | None = None, manager_id: Hashable | None = None, - ref_counts=None, + ref_counts: dict[Any, int] | None = None, ): """Initialize a CachingFileManager. @@ -134,13 +138,17 @@ def __init__( self._mode = mode self._kwargs = {} if kwargs is None else dict(kwargs) - self._use_default_lock = lock is None or lock is False - self._lock = threading.Lock() if self._use_default_lock else lock + if lock is None or lock is False: + self._use_default_lock = True + self._lock: Lock = threading.Lock() + else: + self._use_default_lock = False + self._lock = lock # cache[self._key] stores the file associated with this object. if cache is None: - cache = FILE_CACHE - self._cache = cache + cache = cast(MutableMapping[Any, T_File], FILE_CACHE) + self._cache: MutableMapping[Any, T_File] = cache if manager_id is None: # Each call to CachingFileManager should separately open files. 
manager_id = str(uuid.uuid4()) @@ -155,7 +163,7 @@ def __init__( self._ref_counter = _RefCounter(ref_counts) self._ref_counter.increment(self._key) - def _make_key(self): + def _make_key(self) -> _HashedSequence: """Make a key for caching files in the LRU cache.""" value = ( self._opener, @@ -166,8 +174,8 @@ def _make_key(self): ) return _HashedSequence(value) - @contextlib.contextmanager - def _optional_lock(self, needs_lock): + @contextmanager + def _optional_lock(self, needs_lock: bool): """Context manager for optionally acquiring a lock.""" if needs_lock: with self._lock: @@ -175,7 +183,7 @@ def _optional_lock(self, needs_lock): else: yield - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_File: """Acquire a file object from the manager. A new file is only opened if it has expired from the @@ -193,8 +201,8 @@ def acquire(self, needs_lock=True): file, _ = self._acquire_with_cache_info(needs_lock) return file - @contextlib.contextmanager - def acquire_context(self, needs_lock=True): + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: """Context manager for acquiring a file.""" file, cached = self._acquire_with_cache_info(needs_lock) try: @@ -204,14 +212,14 @@ def acquire_context(self, needs_lock=True): self.close(needs_lock) raise - def _acquire_with_cache_info(self, needs_lock=True): + def _acquire_with_cache_info(self, needs_lock: bool = True) -> tuple[T_File, bool]: """Acquire a file, returning the file and whether it was cached.""" with self._optional_lock(needs_lock): try: file = self._cache[self._key] except KeyError: kwargs = self._kwargs - if self._mode is not _DEFAULT_MODE: + if self._mode is not _OMIT_MODE: kwargs = kwargs.copy() kwargs["mode"] = self._mode file = self._opener(*self._args, **kwargs) @@ -223,7 +231,7 @@ def _acquire_with_cache_info(self, needs_lock=True): else: return file, True - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> 
None: """Explicitly close any associated file object (if necessary).""" # TODO: remove needs_lock if/when we have a reentrant lock in # dask.distributed: https://github.com/dask/dask/issues/3832 @@ -282,7 +290,7 @@ def __setstate__(self, state) -> None: def __repr__(self) -> str: args_string = ", ".join(map(repr, self._args)) - if self._mode is not _DEFAULT_MODE: + if self._mode is not _OMIT_MODE: args_string += f", mode={self._mode!r}" return ( f"{type(self).__name__}({self._opener!r}, {args_string}, " @@ -325,15 +333,15 @@ def __init__(self, tuple_value): self[:] = tuple_value self.hashvalue = hash(tuple_value) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return self.hashvalue -def _get_none(*args, **kwargs) -> None: +def _get_none() -> None: return None -class PickleableFileManager(FileManager): +class PickleableFileManager(FileManager[T_File]): """File manager that supports pickling by reopening a file object. Use PickleableFileManager for wrapping file-like objects that do not natively @@ -344,10 +352,10 @@ class PickleableFileManager(FileManager): def __init__( self, - opener, - *args, - mode=_DEFAULT_MODE, - kwargs=None, + opener: Callable[..., T_File], + *args: Any, + mode: Any = _OMIT_MODE, + kwargs: Mapping[str, Any] | None = None, ): kwargs = {} if kwargs is None else dict(kwargs) self._opener = opener @@ -357,34 +365,43 @@ def __init__( # Note: No need for locking with PickleableFileManager, because all # opening of files happens in the constructor. 
- if mode is not _DEFAULT_MODE: + if mode is not _OMIT_MODE: kwargs = kwargs.copy() kwargs["mode"] = mode - self._file = opener(*args, **kwargs) + self._file: T_File | None = opener(*args, **kwargs) @property def _closed(self) -> bool: # If opener() raised an error in the constructor, _file may not be set return getattr(self, "_file", None) is None - def _assert_open(self): + def _get_unclosed_file(self) -> T_File: if self._closed: - raise ValueError("file is closed") + raise RuntimeError("file is closed") + file = self._file + assert file is not None + return file - def acquire(self, needs_lock=True): - self._assert_open() - return self._file + def acquire(self, needs_lock: bool = True) -> T_File: + del needs_lock # unused + return self._get_unclosed_file() - @contextlib.contextmanager - def acquire_context(self, needs_lock=True): - self._assert_open() - yield self._file + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: + del needs_lock # unused + yield self._get_unclosed_file() - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: + del needs_lock # unused if not self._closed: - assert self._file is not None - self._file.close() + file = self._get_unclosed_file() + file.close() self._file = None + # Remove all references to opener arguments, so they can be garbage + # collected. 
+ self._args = () + self._mode = _OMIT_MODE + self._kwargs = {} def __del__(self) -> None: if not self._closed: @@ -411,15 +428,14 @@ def __repr__(self) -> str: if self._closed: return f"" args_string = ", ".join(map(repr, self._args)) - if self._mode is not _DEFAULT_MODE: + if self._mode is not _OMIT_MODE: args_string += f", mode={self._mode!r}" - if "memory" in self._kwargs: - kwargs = self._kwargs | {"memory": utils.ReprObject("...")} - else: - kwargs = self._kwargs - return ( - f"{type(self).__name__}({self._opener!r}, {args_string}, kwargs={kwargs})" + kwargs = ( + self._kwargs | {"memory": utils.ReprObject("...")} + if "memory" in self._kwargs + else self._kwargs ) + return f"{type(self).__name__}({self._opener!r}, {args_string}, {kwargs=})" @atexit.register @@ -430,24 +446,24 @@ def _remove_del_methods(): del PickleableFileManager.__del__ -class DummyFileManager(FileManager): +class DummyFileManager(FileManager[T_File]): """FileManager that simply wraps an open file in the FileManager interface.""" - def __init__(self, value, *, close=None): + def __init__(self, value: T_File, *, close: Callable[[], None] | None = None): if close is None: close = value.close self._value = value self._close = close - def acquire(self, needs_lock=True): - del needs_lock # ignored + def acquire(self, needs_lock: bool = True) -> T_File: + del needs_lock # unused return self._value - @contextlib.contextmanager - def acquire_context(self, needs_lock=True): - del needs_lock + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: + del needs_lock # unused yield self._value - def close(self, needs_lock=True): - del needs_lock # ignored + def close(self, needs_lock: bool = True) -> None: + del needs_lock # unused self._close() diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 82d3e0b7dae..784443544ee 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -4,15 +4,17 @@ import threading import uuid import weakref 
-from collections.abc import Hashable, MutableMapping -from typing import Any, ClassVar +from collections.abc import Callable, Hashable, MutableMapping, Sequence +from typing import Any, ClassVar, Literal from weakref import WeakValueDictionary +from xarray.core.types import Lock + # SerializableLock is adapted from Dask: # https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224 # Used under the terms of Dask's license, see licenses/DASK_LICENSE. -class SerializableLock: +class SerializableLock(Lock): """A Serializable per-process Lock This wraps a normal ``threading.Lock`` object and satisfies the same @@ -90,7 +92,7 @@ def __str__(self): _FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary() -def _get_threaded_lock(key): +def _get_threaded_lock(key: str) -> threading.Lock: try: lock = _FILE_LOCKS[key] except KeyError: @@ -98,14 +100,14 @@ def _get_threaded_lock(key): return lock -def _get_multiprocessing_lock(key): +def _get_multiprocessing_lock(key: str) -> Lock: # TODO: make use of the key -- maybe use locket.py? # https://github.com/mwilliamson/locket.py del key # unused return multiprocessing.Lock() -def _get_lock_maker(scheduler=None): +def _get_lock_maker(scheduler: str | None = None) -> Callable[..., Lock]: """Returns an appropriate function for creating resource locks. 
Parameters @@ -125,10 +127,8 @@ def _get_lock_maker(scheduler=None): elif scheduler == "distributed": # Lazy import distributed since it is can add a significant # amount of time to import - try: - from dask.distributed import Lock as DistributedLock - except ImportError: - DistributedLock = None + from dask.distributed import Lock as DistributedLock + return DistributedLock else: raise KeyError(scheduler) @@ -172,7 +172,7 @@ def get_dask_scheduler(get=None, collection=None) -> str | None: return "threaded" -def get_write_lock(key): +def get_write_lock(key: str) -> Lock: """Get a scheduler appropriate lock for writing to the given resource. Parameters @@ -207,14 +207,14 @@ def acquire(lock, blocking=True): return lock.acquire(blocking) -class CombinedLock: +class CombinedLock(Lock): """A combination of multiple locks. Like a locked door, a CombinedLock is locked if any of its constituent locks are locked. """ - def __init__(self, locks): + def __init__(self, locks: Sequence[Lock]): self.locks = tuple(set(locks)) # remove duplicates def acquire(self, blocking=True): @@ -239,7 +239,7 @@ def __repr__(self): return f"CombinedLock({list(self.locks)!r})" -class DummyLock: +class DummyLock(Lock): """DummyLock provides the lock API without any actual locking.""" def acquire(self, blocking=True): @@ -258,9 +258,9 @@ def locked(self): return False -def combine_locks(locks): +def combine_locks(locks: Sequence[Lock]) -> Lock: """Combine a sequence of locks into a single lock.""" - all_locks = [] + all_locks: list[Lock] = [] for lock in locks: if isinstance(lock, CombinedLock): all_locks.extend(lock.locks) @@ -276,7 +276,7 @@ def combine_locks(locks): return DummyLock() -def ensure_lock(lock): +def ensure_lock(lock: Lock | None | Literal[False]) -> Lock: """Ensure that the given object is a lock.""" if lock is None or lock is False: return DummyLock() diff --git a/xarray/core/types.py b/xarray/core/types.py index 2305ce56199..a0d62d30c9f 100644 --- a/xarray/core/types.py +++ 
b/xarray/core/types.py @@ -364,3 +364,14 @@ def read(self, n: int = ..., /) -> AnyStr_co: ] ResampleCompatible: TypeAlias = str | datetime.timedelta | pd.Timedelta | pd.DateOffset + + +class Closable(Protocol): + def close(self) -> None: ... + + +class Lock(Protocol): + def acquire(self, *args, **kwargs) -> Any: ... + def release(self) -> None: ... + def __enter__(self) -> Any: ... + def __exit__(self, *args, **kwargs) -> None: ... diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index 2d0475ef31a..b222c537ee4 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -46,20 +46,17 @@ def test_default_engine_h5netcdf(monkeypatch): monkeypatch.delitem(sys.modules, "scipy", raising=False) monkeypatch.setattr(sys, "meta_path", []) - engine = get_default_netcdf_write_engine( - format=None, to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format=None, to_fileobject=False) assert engine == "h5netcdf" with pytest.raises( ValueError, match=re.escape( - "cannot write NetCDF files with format='NETCDF3_CLASSIC' because none of the suitable backend libraries (netCDF4, scipy) are installed" + "cannot write NetCDF files with format='NETCDF3_CLASSIC' because " + "none of the suitable backend libraries (netCDF4, scipy) are installed" ), ): - get_default_netcdf_write_engine( - format="NETCDF3_CLASSIC", to_fileobject_or_memoryview=False - ) + get_default_netcdf_write_engine(format="NETCDF3_CLASSIC", to_fileobject=False) def test_custom_engine() -> None: diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 72e7e3ee77e..667e7323758 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -382,12 +382,6 @@ class TestGenericNetCDFIO(NetCDFIOBase): @requires_h5netcdf @requires_netCDF4 def test_memoryview_write_h5netcdf_read_netcdf4(self, simple_datatree) -> None: - # This test triggers a warning from 
pytype, but does not fail: - # PytestUnraisableExceptionWarning: Exception ignored in: <_io.BytesIO object at 0x32ce0c540> - # BufferError: Existing exports of data: object cannot be re-sized - # The warning is silenced if _either_ memview is converted into bytes or - # the use of PickleableFileManager inside NetCDF4DataStore.open() is - # replaced by DummyFileStore. shoyer suspects a netCDF4-python bug. original_dt = simple_datatree memview = original_dt.to_netcdf(engine="h5netcdf") roundtrip_dt = load_datatree(memview, engine="netcdf4") diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 97b65f876b9..7b5fbee6309 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -278,3 +278,14 @@ def test_pickleable_file_manager_write_pickle(tmpdir) -> None: with open(path) as f: assert f.read() == "foobar" + + +def test_pickleable_file_manager_preserves_closed(tmpdir) -> None: + path = str(tmpdir.join("testing.txt")) + manager = PickleableFileManager(open, path, mode="w") + f = manager.acquire() + f.write("foo") + manager.close() + manager2 = pickle.loads(pickle.dumps(manager)) + assert manager2._closed + assert repr(manager2) == "" From 292d63af37aec3b914da3e43adfdf9a1401f205a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Sep 2025 00:07:43 -0400 Subject: [PATCH 14/37] tweak tests --- xarray/backends/file_manager.py | 5 +-- xarray/backends/h5netcdf_.py | 5 --- xarray/backends/locks.py | 2 +- xarray/backends/netCDF4_.py | 6 +++ xarray/backends/scipy_.py | 11 +++--- xarray/core/datatree_io.py | 2 +- xarray/tests/test_backends.py | 53 +++++++++++++------------- xarray/tests/test_backends_datatree.py | 15 ++++---- 8 files changed, 50 insertions(+), 49 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index d269b1bfb45..f7cd4675729 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ 
-365,9 +365,8 @@ def __init__( # Note: No need for locking with PickleableFileManager, because all # opening of files happens in the constructor. - if mode is not _OMIT_MODE: - kwargs = kwargs.copy() - kwargs["mode"] = mode + if mode != _OMIT_MODE: + kwargs = kwargs | {"mode": mode} self._file: T_File | None = opener(*args, **kwargs) @property diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 077cd2ce382..422eadc6c34 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -462,11 +462,6 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: filename_or_obj = _normalize_filename_or_obj(filename_or_obj) - # magic_number = ( - # bytes(filename_or_obj[:8]) - # if isinstance(filename_or_obj, bytes | memoryview) - # else try_read_magic_number_from_path(filename_or_obj) - # ) magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None: return magic_number.startswith(b"\211HDF\r\n\032\n") diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 784443544ee..2424d8f6fa9 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -100,7 +100,7 @@ def _get_threaded_lock(key: str) -> threading.Lock: return lock -def _get_multiprocessing_lock(key: str) -> Lock: +def _get_multiprocessing_lock(key: str) -> multiprocessing.Lock: # TODO: make use of the key -- maybe use locket.py? 
# https://github.com/mwilliamson/locket.py del key # unused diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b0eb8216c25..234768ef891 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -6,6 +6,7 @@ from collections.abc import Iterable from contextlib import suppress from dataclasses import dataclass +from io import IOBase from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -461,6 +462,11 @@ def open( if isinstance(filename, os.PathLike): filename = os.fspath(filename) + if isinstance(filename, IOBase): + raise TypeError( + f"file objects are not supported by the netCDF4 backend: {filename}" + ) + if not isinstance(filename, str | bytes | memoryview | BytesIOProxy): raise TypeError(f"invalid filename for netCDF4 backend: {filename}") diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 6e1220a36ae..9f6920ca391 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -122,11 +122,13 @@ class flush_only_netcdf_file(netcdf_file_base): # closed when the netcdf_file is garbage collected (via __del__), # and will need to be fixed upstream in scipy. def close(self): - self.flush() + if hasattr(self, "fp") and not self.fp.closed: + self.flush() + self.fp.seek(0) def __del__(self): - # Remove the __del__ method. These files need to be closed explicitly by - # xarray. + # Remove the __del__ method, which in scipy is aliased to close(). + # These files need to be closed explicitly by xarray. 
pass @@ -211,13 +213,12 @@ def __init__( elif hasattr(filename_or_obj, "seek"): # file object # Note: checking for .seek matches the check for file objects # in scipy.io.netcdf_file - flush_only = mode in "wa" scipy_dataset = _open_scipy_netcdf( filename_or_obj, mode=mode, mmap=mmap, version=version, - flush_only=flush_only, + flush_only=True, ) assert not scipy_dataset.use_mmap # no mmap for file objects manager = DummyFileManager(scipy_dataset) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 06d6f4278da..b158862aba0 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -41,7 +41,7 @@ def _datatree_to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, -) -> None | memoryview: +) -> None | memoryview | Delayed: """Implementation of `DataTree.to_netcdf`.""" if format not in [None, *get_args(T_DataTreeNetcdfTypes)]: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 48343fe61ea..f6487fdab2a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2497,7 +2497,7 @@ def test_deepcopy(self) -> None: class InMemoryNetCDF: engine: T_NetcdfEngine | None - def test_roundtrip(self) -> None: + def test_roundtrip_via_memoryview(self) -> None: original = create_test_data() result = original.to_netcdf(engine=self.engine) roundtrip = load_dataset(result, engine=self.engine) @@ -2513,9 +2513,8 @@ def test_pickle_open_dataset_from_bytes(self) -> None: original = Dataset({"foo": ("x", [1, 2, 3])}) netcdf_bytes = bytes(original.to_netcdf(engine=self.engine)) with open_dataset(netcdf_bytes, engine=self.engine) as roundtrip: - unpickled = pickle.loads(pickle.dumps(roundtrip)) - assert_identical(unpickled, original) - unpickled.close() + with pickle.loads(pickle.dumps(roundtrip)) as unpickled: + assert_identical(unpickled, original) def test_compute_false(self) -> None: original = create_test_data() @@ -2537,16 +2536,6 @@ def 
test_roundtrip_group_via_memoryview(self) -> None: class FileObjectNetCDF: engine: T_NetcdfEngine - def test_open_twice(self) -> None: - expected = create_test_data() - expected.attrs["foo"] = "bar" - with create_tmp_file() as tmp_file: - expected.to_netcdf(tmp_file, engine=self.engine) - with open(tmp_file, "rb") as f: - with open_dataset(f, engine=self.engine): - with open_dataset(f, engine=self.engine): - pass - def test_file_remains_open(self) -> None: data = Dataset({"foo": ("x", [1, 2, 3])}) f = BytesIO() @@ -4629,23 +4618,13 @@ def roundtrip( async def test_load_async(self) -> None: await super().test_load_async() - @pytest.mark.xfail(reason="not working yet") - def test_open_twice(self): - super().test_open_twice() - - @pytest.mark.xfail( - reason="scipy.io.netcdf_file closes files upon garbage collection" - ) - def test_file_remains_open(self) -> None: - super().test_file_remains_open() - @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle(self) -> None: - pass + super().test_pickle() @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle_dataarray(self) -> None: - pass + super().test_pickle_dataarray() @pytest.mark.parametrize("create_default_indexes", [True, False]) def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None: @@ -4763,6 +4742,18 @@ def test_engine(self) -> None: with pytest.raises(ValueError, match=r"unrecognized engine"): open_dataset(tmp_file, engine="foobar") + with pytest.raises( + TypeError, + match=re.escape("file objects are not supported by the netCDF4 backend"), + ): + data.to_netcdf(BytesIO(), engine="netcdf4") + + with pytest.raises( + TypeError, + match=re.escape("file objects are not supported by the netCDF4 backend"), + ): + open_dataset(BytesIO(), engine="netcdf4") + bytes_io = BytesIO() data.to_netcdf(bytes_io, engine="scipy") with pytest.raises(ValueError, match=r"unrecognized engine"): @@ -5120,6 +5111,16 @@ def test_open_badbytes(self) -> None: with 
open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): pass + def test_open_twice(self) -> None: + expected = create_test_data() + expected.attrs["foo"] = "bar" + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine=self.engine) + with open(tmp_file, "rb") as f: + with open_dataset(f, engine=self.engine): + with open_dataset(f, engine=self.engine): + pass # should not crash + @requires_scipy def test_open_fileobj(self) -> None: # open in-memory datasets instead of local file paths diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 667e7323758..1679e462680 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -395,9 +395,7 @@ def test_memoryview_write_netcdf4_read_h5netcdf(self, simple_datatree) -> None: roundtrip_dt = load_datatree(memview, engine="h5netcdf") assert_equal(original_dt, roundtrip_dt) - def test_open_datatree(self, unaligned_datatree_nc) -> None: - """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - + def test_open_datatree_unaligned_hierarchy(self, unaligned_datatree_nc) -> None: with pytest.raises( ValueError, match=( @@ -569,9 +567,9 @@ class TestPyDAPDatatreeIO: ) simplegroup_datatree_url = "dap4://test.opendap.org/opendap/dap4/SimpleGroup.nc4.h5" - def test_open_datatree(self, url=unaligned_datatree_url) -> None: - """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - + def test_open_datatree_unaligned_hierarchy( + self, url=unaligned_datatree_url + ) -> None: with pytest.raises( ValueError, match=( @@ -893,8 +891,9 @@ def test_open_groups_round_trip(self, tmpdir, simple_datatree, zarr_format) -> N @pytest.mark.filterwarnings( "ignore:Failed to open Zarr store with consolidated metadata:RuntimeWarning" ) - def test_open_datatree(self, unaligned_datatree_zarr_factory, zarr_format) -> None: - """Test if `open_datatree` fails to open a zarr store with an 
unaligned group hierarchy.""" + def test_open_datatree_unaligned_hierarchy( + self, unaligned_datatree_zarr_factory, zarr_format + ) -> None: storepath = unaligned_datatree_zarr_factory(zarr_format=zarr_format) with pytest.raises( From d6336866d33ce36f9aa9e7e48631a7912144bf43 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Sep 2025 00:19:55 -0400 Subject: [PATCH 15/37] remove unnecessary seek --- xarray/backends/scipy_.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 9f6920ca391..0446285edd5 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -124,7 +124,6 @@ class flush_only_netcdf_file(netcdf_file_base): def close(self): if hasattr(self, "fp") and not self.fp.closed: self.flush() - self.fp.seek(0) def __del__(self): # Remove the __del__ method, which in scipy is aliased to close(). From dfc651d262ce9759069b2b564c5df8d6ad92364d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Sep 2025 00:34:04 -0400 Subject: [PATCH 16/37] Add release notes --- doc/whats-new.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 72344df4658..b72350b9b1f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,10 @@ v2025.09.1 (unreleased) New Features ~~~~~~~~~~~~ +- ``engine='netcdf4'`` now supports reading and writing in-memory netCDF files. + All of Xarray's netCDF backends now support in-memory reads and writes + (:pull:`10624`). + By `Stephan Hoyer `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -29,6 +33,17 @@ Deprecations Bug fixes ~~~~~~~~~ +- Xarray objects opened from file-like objects with ``engine='h5netcdf'`` can + now be pickled, as long as the underlying file-like object also support + pickle + (:pull:`10624`). + By `Stephan Hoyer `_. 
+- Closing Xarray objects opened from file-like objects with ```engine='scipy'`` + no longer closes the underlying file, consistent the h5netcdf backend + (:pull:`10624`). + By `Stephan Hoyer `_. + + Documentation ~~~~~~~~~~~~~ From f267148680870a12049e3c47f5ff62ac61552cb9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 9 Sep 2025 00:29:37 -0400 Subject: [PATCH 17/37] fix test failures --- doc/whats-new.rst | 7 +++-- xarray/backends/locks.py | 2 +- xarray/backends/scipy_.py | 49 ++++++++++++++++------------------- xarray/tests/test_backends.py | 29 ++++++++++++++------- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b72350b9b1f..85dbd955672 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,17 +34,16 @@ Bug fixes ~~~~~~~~~ - Xarray objects opened from file-like objects with ``engine='h5netcdf'`` can - now be pickled, as long as the underlying file-like object also support + now be pickled, as long as the underlying file-like object also supports pickle - (:pull:`10624`). + (:issue:`10712`). By `Stephan Hoyer `_. -- Closing Xarray objects opened from file-like objects with ```engine='scipy'`` +- Closing Xarray objects opened from file-like objects with ```engine='scipy'`` no longer closes the underlying file, consistent the h5netcdf backend (:pull:`10624`). By `Stephan Hoyer `_. - Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 2424d8f6fa9..784443544ee 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -100,7 +100,7 @@ def _get_threaded_lock(key: str) -> threading.Lock: return lock -def _get_multiprocessing_lock(key: str) -> multiprocessing.Lock: +def _get_multiprocessing_lock(key: str) -> Lock: # TODO: make use of the key -- maybe use locket.py? 
# https://github.com/mwilliamson/locket.py del key # unused diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 0446285edd5..a60dfea2c4c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -35,12 +35,6 @@ ) from xarray.core.variable import Variable -try: - from scipy.io import netcdf_file as netcdf_file_base -except ImportError: - netcdf_file_base = object - - if TYPE_CHECKING: import scipy.io @@ -110,30 +104,31 @@ def __setitem__(self, key, value): raise -class flush_only_netcdf_file(netcdf_file_base): - # scipy.io.netcdf_file.close() incorrectly closes file objects that - # were passed in as constructor arguments: - # https://github.com/scipy/scipy/issues/13905 - - # Instead of closing such files, only call flush(), which is - # equivalent as long as the netcdf_file object is not mmapped. - # This suffices to keep BytesIO objects open long enough to read - # their contents from to_netcdf(), but underlying files still get - # closed when the netcdf_file is garbage collected (via __del__), - # and will need to be fixed upstream in scipy. - def close(self): - if hasattr(self, "fp") and not self.fp.closed: - self.flush() - - def __del__(self): - # Remove the __del__ method, which in scipy is aliased to close(). - # These files need to be closed explicitly by xarray. - pass - - def _open_scipy_netcdf(filename, mode, mmap, version, flush_only=False): import scipy.io + # define inside a helper function to ensure the scipy import is lazy + class flush_only_netcdf_file(scipy.io.netcdf_file): + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # https://github.com/scipy/scipy/issues/13905 + + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. 
+ # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. + def close(self): + if hasattr(self, "fp") and not self.fp.closed: + self.flush() + self.fp.seek(0) # allow file to be read again + + def __del__(self): + # Remove the __del__ method, which in scipy is aliased to close(). + # These files need to be closed explicitly by xarray. + pass + netcdf_file = flush_only_netcdf_file if flush_only else scipy.io.netcdf_file # if the string ends with .gz, then gunzip and open as netcdf file diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index fe46671c7c1..8758abd0c9b 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -116,12 +116,12 @@ with contextlib.suppress(ImportError): import netCDF4 as nc4 -try: +with contextlib.suppress(ImportError): import dask import dask.array as da -except ImportError: - pass +with contextlib.suppress(ImportError): + import fsspec if has_zarr: import zarr @@ -633,16 +633,13 @@ def test_pickle(self) -> None: with pickle.loads(raw_pickle) as unpickled_ds: assert_identical(expected, unpickled_ds) - @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_pickle_dataarray(self) -> None: expected = Dataset({"foo": ("x", [42])}) with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: raw_pickle = pickle.dumps(roundtripped["foo"]) - # TODO: figure out how to explicitly close the file for the - # unpickled DataArray? 
- unpickled = pickle.loads(raw_pickle) - assert_identical(expected["foo"], unpickled) + with pickle.loads(raw_pickle) as unpickled: + assert_identical(expected["foo"], unpickled) def test_dataset_caching(self) -> None: expected = Dataset({"foo": ("x", [5, 6, 7])}) @@ -658,7 +655,6 @@ def test_dataset_caching(self) -> None: _ = actual.foo.values # no caching assert not actual.foo.variable._in_memory - @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_roundtrip_None_variable(self) -> None: expected = Dataset({None: (("x", "y"), [[0, 1], [2, 3]])}) with self.roundtrip(expected) as actual: @@ -5113,7 +5109,6 @@ def test_open_badbytes(self) -> None: def test_open_twice(self) -> None: expected = create_test_data() - expected.attrs["foo"] = "bar" with create_tmp_file() as tmp_file: expected.to_netcdf(tmp_file, engine=self.engine) with open(tmp_file, "rb") as f: @@ -5154,6 +5149,20 @@ def test_open_fileobj(self) -> None: with open_dataset(f): # ensure file gets closed pass + @requires_fsspec + def test_fsspec(self) -> None: + expected = create_test_data() + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine="h5netcdf") + + with fsspec.open(tmp_file, "rb") as f: + with open_dataset(f, engine="h5netcdf") as actual: + assert_identical(actual, expected) + + # fsspec.open() creates a pickleable file, unlike open() + with pickle.loads(pickle.dumps(actual)) as unpickled: + assert_identical(unpickled, expected) + @requires_h5netcdf class TestH5NetCDFInMemoryData(InMemoryNetCDFWithGroups): From 21006738826240ee33f9d5451870336a2f88bb7d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 9 Sep 2025 01:37:32 -0400 Subject: [PATCH 18/37] tweak whats new --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f60a5facb6e..483bc124771 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,9 +35,9 @@ Bug fixes - Xarray objects opened from file-like 
objects with ``engine='h5netcdf'`` can now be pickled, as long as the underlying file-like object also supports - pickle - (:issue:`10712`). + pickle (:issue:`10712`). By `Stephan Hoyer `_. + - Closing Xarray objects opened from file-like objects with ```engine='scipy'`` no longer closes the underlying file, consistent the h5netcdf backend (:pull:`10624`). From 5a2dc42f9e1d3f77d59efa4177febc5d7a89a675 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 9 Sep 2025 01:57:30 -0400 Subject: [PATCH 19/37] Make scipy import no longer lazy --- xarray/backends/scipy_.py | 51 ++++++++++++++++++++---------------- xarray/tests/test_plugins.py | 2 +- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index a60dfea2c4c..db188d88b9c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -35,6 +35,12 @@ ) from xarray.core.variable import Variable +try: + from scipy.io import netcdf_file as netcdf_file_base +except ImportError: + netcdf_file_base = object + + if TYPE_CHECKING: import scipy.io @@ -104,31 +110,32 @@ def __setitem__(self, key, value): raise +# TODO: Make the scipy import lazy again after upstreaming these fixes. +class flush_only_netcdf_file(netcdf_file_base): + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # https://github.com/scipy/scipy/issues/13905 + + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. + # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. 
+ def close(self): + if hasattr(self, "fp") and not self.fp.closed: + self.flush() + self.fp.seek(0) # allow file to be read again + + def __del__(self): + # Remove the __del__ method, which in scipy is aliased to close(). + # These files need to be closed explicitly by xarray. + pass + + def _open_scipy_netcdf(filename, mode, mmap, version, flush_only=False): import scipy.io - # define inside a helper function to ensure the scipy import is lazy - class flush_only_netcdf_file(scipy.io.netcdf_file): - # scipy.io.netcdf_file.close() incorrectly closes file objects that - # were passed in as constructor arguments: - # https://github.com/scipy/scipy/issues/13905 - - # Instead of closing such files, only call flush(), which is - # equivalent as long as the netcdf_file object is not mmapped. - # This suffices to keep BytesIO objects open long enough to read - # their contents from to_netcdf(), but underlying files still get - # closed when the netcdf_file is garbage collected (via __del__), - # and will need to be fixed upstream in scipy. - def close(self): - if hasattr(self, "fp") and not self.fp.closed: - self.flush() - self.fp.seek(0) # allow file to be read again - - def __del__(self): - # Remove the __del__ method, which in scipy is aliased to close(). - # These files need to be closed explicitly by xarray. 
- pass - netcdf_file = flush_only_netcdf_file if flush_only else scipy.io.netcdf_file # if the string ends with .gz, then gunzip and open as netcdf file diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index f1342a1f82a..c23a5487bd6 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -227,7 +227,7 @@ def test_lazy_import() -> None: "numbagg", "pint", "pydap", - "scipy", + # "scipy", # TODO: xarray.backends.scipy_ is currently not lazy "sparse", "zarr", ] From cdc75236c39f7644233a9cf0a99dd5d25de19a03 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 9 Sep 2025 10:27:19 -0400 Subject: [PATCH 20/37] Update doc/whats-new.rst MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Kai Mühlbauer --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 483bc124771..d2aad414315 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,7 +39,7 @@ Bug fixes By `Stephan Hoyer `_. - Closing Xarray objects opened from file-like objects with ```engine='scipy'`` - no longer closes the underlying file, consistent the h5netcdf backend + no longer closes the underlying file, consistent with the h5netcdf backend (:pull:`10624`). By `Stephan Hoyer `_. 
From 190e962052e2aef8dac8fa00941ed6f6e253338c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 15 Sep 2025 18:01:02 -0700 Subject: [PATCH 21/37] Fix error on Windows --- xarray/tests/test_backends_datatree.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 1679e462680..8d03781592d 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -33,6 +33,9 @@ import netCDF4 as nc4 +ON_WINDOWS = sys.platform == "win32" + + class TestNetCDF4DataTree(_TestNetCDF4Data): @contextlib.contextmanager def open(self, path, **kwargs): @@ -303,9 +306,12 @@ def test_compute_false(self, tmpdir, simple_datatree): original_dt = simple_datatree.chunk() result = original_dt.to_netcdf(filepath, engine=self.engine, compute=False) - with open_datatree(filepath, engine=self.engine) as in_progress_dt: - assert in_progress_dt.isomorphic(original_dt) - assert not in_progress_dt.equals(original_dt) + if not ON_WINDOWS: + # File at filepath is not closed until .compute() is called. On + # Windows, this means we can't open it yet. 
+ with open_datatree(filepath, engine=self.engine) as in_progress_dt: + assert in_progress_dt.isomorphic(original_dt) + assert not in_progress_dt.equals(original_dt) result.compute() with open_datatree(filepath, engine=self.engine) as written_dt: From 4d224684e7f3e7367be49eb0eb2355651462563f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 10 Sep 2025 09:07:52 -0700 Subject: [PATCH 22/37] Silence Zarr v3 warnings (#10731) responsible for 1400 / 1440 warnings; and don't think there's much we can do about them --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 95a3f65dd8a..5a71ac7de00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -363,6 +363,9 @@ filterwarnings = [ "default:Zarr-Python is not in alignment with the final V3 specification", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "default:is currently not part .* the Zarr version 3 specification.", + # Zarr V3 data type specifications warnings - very repetitive + "ignore:The data type .* does not have a Zarr V3 specification", + "ignore:Consolidated metadata is currently not part", # TODO: remove once we know how to deal with a changed signature in protocols "default:::xarray.tests.test_strategies", ] From 8eedfd82a69b7f24e6fbe29c6a950366df95efe6 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 10 Sep 2025 20:18:19 +0200 Subject: [PATCH 23/37] propagate attrs on coords in `Dataset.map` (#10602) * check that weighted ops propagate attrs on coords * propagate attrs on coords in `map` if keep_attrs * directly check that `map` propagates attrs on coords * whats-new --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 13 ++++++++++++- xarray/tests/test_dataset.py | 32 ++++++++++++++++++++++++++++++++ xarray/tests/test_weighted.py | 11 +++++++++++ 4 files changed, 57 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index
83d5ec3e7cd..5d4d9896180 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -48,6 +48,8 @@ Bug fixes - Fix error when encoding an empty :py:class:`numpy.datetime64` array (:issue:`10722`, :pull:`10723`). By `Spencer Clark `_. +- Propagation coordinate attrs in :py:meth:`xarray.Dataset.map` (:issue:`9317`, :pull:`10602`). + By `Justus Magin `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 11f56d3ad44..650c05df73b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6929,11 +6929,22 @@ def map( k: maybe_wrap_array(v, func(v, *args, **kwargs)) for k, v in self.data_vars.items() } + coord_vars, indexes = merge_coordinates_without_align( + [v.coords for v in variables.values()] + ) + coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes) + if keep_attrs: for k, v in variables.items(): v._copy_attrs_from(self.data_vars[k]) + + for k, v in coords.items(): + if k not in self.coords: + continue + v._copy_attrs_from(self.coords[k]) + attrs = self.attrs if keep_attrs else None - return type(self)(variables, attrs=attrs) + return type(self)(variables, coords=coords, attrs=attrs) def apply( self, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2cafb1f2fc1..959177dec68 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6236,6 +6236,38 @@ def scale(x, multiple=1): expected = data.drop_vars("time") # time is not used on a data var assert_equal(expected, actual) + def test_map_coords_attrs(self) -> None: + ds = xr.Dataset( + { + "a": ( + ["x", "y", "z"], + np.arange(24).reshape(3, 4, 2), + {"attr1": "value1"}, + ), + "b": ("y", np.arange(4), {"attr2": "value2"}), + }, + coords={ + "x": ("x", np.array([-1, 0, 1]), {"attr3": "value3"}), + "z": ("z", list("ab"), {"attr4": "value4"}), + }, + ) + + def func(arr): + if "y" not in arr.dims: + return arr + + # drop attrs from coords + return arr.mean(dim="y").drop_attrs() + + expected = 
ds.mean(dim="y", keep_attrs=True) + actual = ds.map(func, keep_attrs=True) + + assert_identical(actual, expected) + assert actual["x"].attrs + + ds["x"].attrs["y"] = "x" + assert ds["x"].attrs != actual["x"].attrs + def test_apply_pending_deprecated_map(self) -> None: data = create_test_data() data.attrs["foo"] = "bar" diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index e9be98ab76b..5d27794cc8d 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -770,6 +770,17 @@ def test_weighted_operations_keep_attr_da_in_ds(operation): assert data.a.attrs == result.a.attrs +def test_weighted_mean_keep_attrs_ds(): + weights = DataArray(np.random.randn(2)) + data = Dataset( + {"a": (["dim_0", "dim_1"], np.random.randn(2, 2), dict(attr="data"))}, + coords={"dim_1": ("dim_1", ["a", "b"], {"attr1": "value1"})}, + ) + + result = data.weighted(weights).mean(dim="dim_0", keep_attrs=True) + assert data.coords["dim_1"].attrs == result.coords["dim_1"].attrs + + @pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile")) @pytest.mark.parametrize("as_dataset", (True, False)) def test_weighted_bad_dim(operation, as_dataset): From 18af4b30d8c86dd97e4aa2254aa00c5a9a55e3a1 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 10 Sep 2025 15:51:48 -0400 Subject: [PATCH 24/37] Add pep-723 style script to bug report issue template (#10707) * Add pep-723 style script to bug report issue template * make suggestion rather than requirement * Update .github/ISSUE_TEMPLATE/bugreport.yml Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * use complete * Update .github/ISSUE_TEMPLATE/bugreport.yml Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Update .github/ISSUE_TEMPLATE/bugreport.yml Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --------- Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- 
.github/ISSUE_TEMPLATE/bugreport.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index 5bd7efd12f1..dca031aca78 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -39,6 +39,23 @@ body: - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) + Consider listing additional or specific dependencies in [inline script metadata](https://packaging.python.org/en/latest/specifications/inline-script-metadata/#example) + so that calling `uv run issue.py` shows the issue when copied into `issue.py`. (not strictly required) + value: | + ```python + # /// script + # requires-python = ">=3.11" + # dependencies = [ + # "xarray[complete]@git+https://github.com/pydata/xarray.git@main, + # ] + # /// + # + # This script automatically imports the development branch of xarray to check for issues + + import xarray as xr + xr.show_versions() + # your reproducer code ... + ``` options: - label: Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray. - label: Complete example — the example is self-contained, including all data and the text of any traceback. 
From 678c0040b392d1ac8978dc3112da1f29bd86ed02 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 10 Sep 2025 20:26:38 -0600 Subject: [PATCH 25/37] Allow `mean` with time dtypes (#10227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow `mean` with time dtypes Closes #5897 Closes #6995 Closes #10217 * Allow mean() to work with datetime dtypes - Changed numeric_only from True to False in mean() aggregations - Added Dataset.mean() override to filter out string variables while preserving datetime/timedelta types - Only filters data variables, preserves all coordinates - Added comprehensive tests for datetime mean with edge cases including NaT handling - Tests cover Dataset, DataArray, groupby, and mixed type scenarios This enables mean() to work with datetime64 and timedelta64 types as requested in PR #10227 while preventing errors from string variables. * Skip cftime test when cftime not available Add pytest.mark.skipif decorator to test_mean_preserves_non_string_object_arrays to skip the test in minimal dependency environments where cftime is not installed. 
* Fix string/datetime handling in mean() aggregations - Revert numeric_only back to True for mean() to prevent strings from being included - Add datetime64/timedelta64 types to numeric_only checks in Dataset.reduce() and flox path - Also check for object arrays containing datetime-like objects (cftime dates) - This allows mean() to work with datetime types while excluding strings that would cause errors * Trigger CI workflow * Format groupby.py numeric_only condition Auto-formatted the multi-line condition for better readability * Apply formatting changes to dataset.py and test_groupby.py - Auto-formatted multi-line conditions for better readability 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix numeric_only parameter in groupby mean for non-flox path The generated mean() methods for DatasetGroupBy and ResampleGroupBy were incorrectly passing numeric_only=False in the non-flox path, causing string variables to fail during reduction. Changed to numeric_only=True to match the flox path behavior. This fixes test_groupby_dataset_reduce_ellipsis failures. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Simplify dtype checking with _is_numeric_aggregatable_dtype helper Created a canonical helper function to check if a dtype can be used in numeric aggregations. This replaces complex repeated conditionals in dataset.py and groupby.py with a single, well-documented function. 
The helper checks for: - Numeric types (int, float, complex) - Boolean type - Datetime types (datetime64, timedelta64) - Object arrays containing datetime-like objects (e.g., cftime) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Address review feedback from @dcherian - Document datetime mean behavior in generate_aggregations.py - Simplify test assertion using isnull().item() instead of pd.notna - Remove redundant test_mean_dataarray_datetime test - Add comprehensive tests for linked issues: - Issue #5897: Test mean with cftime objects - Issue #6995: Test groupby_bins with datetime64 mean - Issue #10217: Test groupby_bins mean on time series data 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Fix test_mean_with_cftime_objects for non-dask environments The test was unconditionally using dask operations which caused failures in the "all-but-dask" CI environment. Now the dask-specific tests are only run when dask is available. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Use standard @requires_dask decorator for dask-specific tests Split test_mean_with_cftime_objects into two tests: - Base test runs without dask - Dask-specific test uses @requires_dask decorator This follows xarray's standard pattern for dependency-specific tests and is cleaner than using if has_dask conditionals. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Add testing guidelines and use proper decorators - Created xarray/tests/CLAUDE.md with comprehensive guidelines for handling optional dependencies in tests - Updated cftime tests to use @requires_cftime decorator instead of pytest.importorskip, following xarray's standard patterns - This ensures consistent handling across CI environments 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix CLAUDE.md blackdoc formatting Ensure all Python code blocks have complete, valid syntax for blackdoc. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Simplify cftime check in _is_numeric_aggregatable_dtype As suggested by @dcherian in review comment r2337966239, directly use _contains_cftime_datetimes(var._data) instead of the more complex check. This is cleaner since _contains_cftime_datetimes already handles the object dtype check internally. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Streamline CLAUDE.md and move import to top of file - Made CLAUDE.md more concise by removing verbose decorator listings - Moved _is_numeric_aggregatable_dtype import to top of groupby.py as suggested by @dcherian in review comment r2337968851 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Address review comments: clarify numeric_only scope and merge tests - Move datetime/timedelta note to global _NUMERIC_ONLY_NOTES constant (addresses comment r2337970143) - Merge test_mean_preserves_non_string_object_arrays into test_mean_with_cftime_objects (addresses comment r2337128692) - Both changes address reviewer feedback about simplification * Clarify that @requires decorators should be used instead of skipif patterns Update testing guidelines to explicitly discourage pytest.mark.skipif in parametrize, recommending splitting tests or using @requires decorators * Fix CLAUDE.md blackdoc formatting Co-authored-by: Claude --------- Co-authored-by: Maximilian Roos Co-authored-by: Claude Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Claude --- xarray/core/_aggregations.py | 16 -- xarray/core/common.py | 18 +++ xarray/core/dataset.py | 4 +- xarray/core/groupby.py | 8 +- xarray/namedarray/_aggregations.py | 4 - xarray/tests/CLAUDE.md | 132 ++++++++++++++++ xarray/tests/test_groupby.py | 228 +++++++++++++++++++++++++++ xarray/util/generate_aggregations.py | 2 +- 8 files changed, 387 insertions(+), 25 deletions(-) create mode 100644 xarray/tests/CLAUDE.md diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 6b1029791ea..adc064840de 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -1776,10 +1776,6 @@ def mean( :ref:`agg` User guide on reduction or aggregation operations. - Notes - ----- - Non-numeric variables will be removed prior to reducing. 
- Examples -------- >>> da = xr.DataArray( @@ -2948,10 +2944,6 @@ def mean( :ref:`agg` User guide on reduction or aggregation operations. - Notes - ----- - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> da = xr.DataArray( @@ -4231,8 +4223,6 @@ def mean( Pass flox-specific keyword arguments in ``**kwargs``. See the `flox documentation `_ for more. - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> da = xr.DataArray( @@ -5729,8 +5719,6 @@ def mean( Pass flox-specific keyword arguments in ``**kwargs``. See the `flox documentation `_ for more. - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> da = xr.DataArray( @@ -7188,8 +7176,6 @@ def mean( Pass flox-specific keyword arguments in ``**kwargs``. See the `flox documentation `_ for more. - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> da = xr.DataArray( @@ -8578,8 +8564,6 @@ def mean( Pass flox-specific keyword arguments in ``**kwargs``. See the `flox documentation `_ for more. - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> da = xr.DataArray( diff --git a/xarray/core/common.py b/xarray/core/common.py index a190766b01a..8f789c35445 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -2108,3 +2108,21 @@ def _contains_datetime_like_objects(var: T_Variable) -> bool: np.datetime64, np.timedelta64, or cftime.datetime) """ return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var) + + +def _is_numeric_aggregatable_dtype(var: T_Variable) -> bool: + """Check if a variable's dtype can be used in numeric aggregations like mean(). 
+ + This includes: + - Numeric types (int, float, complex) + - Boolean type + - Datetime types (datetime64, timedelta64) + - Object arrays containing datetime-like objects (e.g., cftime) + """ + return ( + np.issubdtype(var.dtype, np.number) + or (var.dtype == np.bool_) + or np.issubdtype(var.dtype, np.datetime64) + or np.issubdtype(var.dtype, np.timedelta64) + or _contains_cftime_datetimes(var._data) + ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 650c05df73b..377f6db28f3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -40,6 +40,7 @@ from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, + _is_numeric_aggregatable_dtype, get_chunksizes, ) from xarray.core.coordinates import ( @@ -6847,8 +6848,7 @@ def reduce( and ( not reduce_dims or not numeric_only - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_) + or _is_numeric_aggregatable_dtype(var) ) ): # prefer to aggregate over axis=None rather than diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7a537355ac4..723a89ea3b7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -22,7 +22,11 @@ DataArrayGroupByAggregations, DatasetGroupByAggregations, ) -from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce +from xarray.core.common import ( + ImplementsArrayReduce, + ImplementsDatasetReduce, + _is_numeric_aggregatable_dtype, +) from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat @@ -1068,7 +1072,7 @@ def _flox_reduce( name: var for name, var in variables.items() if ( - not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_)) + not _is_numeric_aggregatable_dtype(var) # this avoids dropping any levels of a MultiIndex, which raises # a warning and name not in midx_grouping_vars diff --git a/xarray/namedarray/_aggregations.py 
b/xarray/namedarray/_aggregations.py index 139cea83b5b..c5726ef9251 100644 --- a/xarray/namedarray/_aggregations.py +++ b/xarray/namedarray/_aggregations.py @@ -352,10 +352,6 @@ def mean( :ref:`agg` User guide on reduction or aggregation operations. - Notes - ----- - Non-numeric variables will be removed prior to reducing. - Examples -------- >>> from xarray.namedarray.core import NamedArray diff --git a/xarray/tests/CLAUDE.md b/xarray/tests/CLAUDE.md new file mode 100644 index 00000000000..3d94a0509bc --- /dev/null +++ b/xarray/tests/CLAUDE.md @@ -0,0 +1,132 @@ +# Testing Guidelines for xarray + +## Handling Optional Dependencies + +xarray has many optional dependencies that may not be available in all testing environments. Always use the standard decorators and patterns when writing tests that require specific dependencies. + +### Standard Decorators + +**ALWAYS use decorators** like `@requires_dask`, `@requires_cftime`, etc. instead of conditional `if` statements. + +All available decorators are defined in `xarray/tests/__init__.py` (look for `requires_*` decorators). + +### DO NOT use conditional imports or skipif + +❌ **WRONG - Do not do this:** + +```python +def test_mean_with_cftime(): + if has_dask: # WRONG! + ds = ds.chunk({}) + result = ds.mean() +``` + +❌ **ALSO WRONG - Avoid pytest.mark.skipif in parametrize:** + +```python +@pytest.mark.parametrize( + "chunk", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_dask, reason="requires dask") + ), + False, + ], +) +def test_something(chunk): ... +``` + +✅ **CORRECT - Do this instead:** + +```python +def test_mean_with_cftime(): + # Test without dask + result = ds.mean() + + +@requires_dask +def test_mean_with_cftime_dask(): + # Separate test for dask functionality + ds = ds.chunk({}) + result = ds.mean() +``` + +✅ **OR for parametrized tests, split them:** + +```python +def test_something_without_dask(): + # Test the False case + ... 
+ + +@requires_dask +def test_something_with_dask(): + # Test the True case with dask + ... +``` + +### Multiple dependencies + +When a test requires multiple optional dependencies: + +```python +@requires_dask +@requires_scipy +def test_interpolation_with_dask(): ... +``` + +### Importing optional dependencies in tests + +For imports within test functions, use `pytest.importorskip`: + +```python +def test_cftime_functionality(): + cftime = pytest.importorskip("cftime") + # Now use cftime +``` + +### Common patterns + +1. **Split tests by dependency** - Don't mix optional dependency code with base functionality: + + ```python + def test_base_functionality(): + # Core test without optional deps + result = ds.mean() + assert result is not None + + + @requires_dask + def test_dask_functionality(): + # Dask-specific test + ds_chunked = ds.chunk({}) + result = ds_chunked.mean() + assert result is not None + ``` + +2. **Use fixtures for dependency-specific setup**: + + ```python + @pytest.fixture + def dask_array(): + pytest.importorskip("dask.array") + import dask.array as da + + return da.from_array([1, 2, 3], chunks=2) + ``` + +3. **Check available implementations**: + + ```python + from xarray.core.duck_array_ops import available_implementations + + + @pytest.mark.parametrize("implementation", available_implementations()) + def test_with_available_backends(implementation): ... 
+ ``` + +### Key Points + +- CI environments intentionally exclude certain dependencies (e.g., `all-but-dask`, `bare-minimum`) +- A test failing in "all-but-dask" because it uses dask is a test bug, not a CI issue +- Look at similar existing tests for patterns to follow diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 94e88fa1dd8..202b729abcf 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3606,6 +3606,234 @@ def test_season_resampler_groupby_identical(self): assert_identical(rs, gb) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_dask, reason="requires dask") + ), + False, + ], +) +def test_datetime_mean(chunk, use_cftime): + ds = xr.Dataset( + { + "var1": ( + ("time",), + xr.date_range( + "2021-10-31", periods=10, freq="D", use_cftime=use_cftime + ), + ), + "var2": (("x",), list(range(10))), + } + ) + if chunk: + ds = ds.chunk() + assert "var1" in ds.groupby("x").mean("time") + assert "var1" in ds.mean("x") + + +def test_mean_with_mixed_types(): + """Test that mean correctly handles datasets with mixed types including strings""" + ds = xr.Dataset( + { + "numbers": (("x",), [1.0, 2.0, 3.0, 4.0]), + "integers": (("x",), [10, 20, 30, 40]), + "strings": (("x",), ["a", "b", "c", "d"]), + "datetime": ( + ("x",), + pd.date_range("2021-01-01", periods=4, freq="D"), + ), + "timedelta": ( + ("x",), + pd.timedelta_range("1 day", periods=4, freq="D"), + ), + } + ) + + # Direct mean should exclude strings but include datetime/timedelta + result = ds.mean() + assert "numbers" in result.data_vars + assert "integers" in result.data_vars + assert "strings" not in result.data_vars + assert "datetime" in result.data_vars + assert "timedelta" in result.data_vars + + # Also test mean with specific dimension + result_dim = ds.mean("x") + assert "numbers" in result_dim.data_vars + assert "integers" in result_dim.data_vars + assert "strings" not in result_dim.data_vars + 
assert "datetime" in result_dim.data_vars + assert "timedelta" in result_dim.data_vars + + +def test_mean_with_string_coords(): + """Test that mean works when strings are in coordinates, not data vars""" + ds = xr.Dataset( + { + "temperature": (("city", "time"), np.random.rand(3, 4)), + "humidity": (("city", "time"), np.random.rand(3, 4)), + }, + coords={ + "city": ["New York", "London", "Tokyo"], + "time": pd.date_range("2021-01-01", periods=4, freq="D"), + }, + ) + + # Mean across string coordinate should work + result = ds.mean("city") + assert result.sizes == {"time": 4} + assert "temperature" in result.data_vars + assert "humidity" in result.data_vars + + # Groupby with string coordinate should work + grouped = ds.groupby("city") + result_grouped = grouped.mean() + assert "temperature" in result_grouped.data_vars + assert "humidity" in result_grouped.data_vars + + +def test_mean_datetime_edge_cases(): + """Test mean with datetime edge cases like NaT""" + # Test with NaT values + dates_with_nat = pd.date_range("2021-01-01", periods=4, freq="D") + dates_with_nat_array = dates_with_nat.values.copy() + dates_with_nat_array[1] = np.datetime64("NaT") + + ds = xr.Dataset( + { + "dates": (("x",), dates_with_nat_array), + "values": (("x",), [1.0, 2.0, 3.0, 4.0]), + } + ) + + # Mean should handle NaT properly (skipna behavior) + result = ds.mean() + assert "dates" in result.data_vars + assert "values" in result.data_vars + # The mean should skip NaT and compute mean of the other 3 dates + assert not result.dates.isnull().item() + + # Test with timedelta + timedeltas = pd.timedelta_range("1 day", periods=4, freq="D") + ds_td = xr.Dataset( + { + "timedeltas": (("x",), timedeltas), + "values": (("x",), [1.0, 2.0, 3.0, 4.0]), + } + ) + + result_td = ds_td.mean() + assert "timedeltas" in result_td.data_vars + assert result_td["timedeltas"].values == np.timedelta64( + 216000000000000, "ns" + ) # 2.5 days + + +@requires_cftime +def test_mean_with_cftime_objects(): + """Test 
mean with cftime objects (issue #5897)""" + ds = xr.Dataset( + { + "var1": ( + ("time",), + xr.date_range("2021-10-31", periods=10, freq="D", use_cftime=True), + ), + "var2": (("x",), list(range(10))), + } + ) + + # Test averaging over time dimension - var1 should be included + result_time = ds.mean("time") + assert "var1" in result_time.data_vars + assert "var2" not in result_time.dims + + # Test averaging over x dimension - should work normally + result_x = ds.mean("x") + assert "var2" in result_x.data_vars + assert "var1" in result_x.data_vars + assert result_x.var2.item() == 4.5 # mean of 0-9 + + # Test that mean preserves object arrays containing datetime-like objects + import cftime + + dates = np.array( + [cftime.DatetimeNoLeap(2021, i, 1) for i in range(1, 5)], dtype=object + ) + ds2 = xr.Dataset( + { + "cftime_dates": (("x",), dates), + "numbers": (("x",), [1.0, 2.0, 3.0, 4.0]), + "object_strings": (("x",), np.array(["a", "b", "c", "d"], dtype=object)), + } + ) + + # Mean should include cftime dates but not string objects + result = ds2.mean() + assert "cftime_dates" in result.data_vars + assert "numbers" in result.data_vars + assert "object_strings" not in result.data_vars + + +@requires_dask +@requires_cftime +def test_mean_with_cftime_objects_dask(): + """Test mean with cftime objects using dask backend (issue #5897)""" + ds = xr.Dataset( + { + "var1": ( + ("time",), + xr.date_range("2021-10-31", periods=10, freq="D", use_cftime=True), + ), + "var2": (("x",), list(range(10))), + } + ) + + # Test with dask backend + dsc = ds.chunk({}) + result_time_dask = dsc.mean("time") + assert "var1" in result_time_dask.data_vars + + result_x_dask = dsc.mean("x") + assert "var2" in result_x_dask.data_vars + assert result_x_dask.var2.compute().item() == 4.5 + + +def test_groupby_bins_datetime_mean(): + """Test groupby_bins with datetime mean (issue #6995)""" + times = pd.date_range("2020-01-01", "2020-02-01", freq="1h") + index = np.arange(len(times)) + bins = 
np.arange(0, len(index), 5) + + ds = xr.Dataset( + {"time": ("index", times), "float": ("index", np.linspace(0, 1, len(index)))}, + coords={"index": index}, + ) + + # The time variable should be preserved and averaged + result = ds.groupby_bins("index", bins).mean() + assert "time" in result.data_vars + assert "float" in result.data_vars + assert result.time.dtype == np.dtype("datetime64[ns]") + + +def test_groupby_bins_mean_time_series(): + """Test groupby_bins mean on time series data (issue #10217)""" + ds = xr.Dataset( + { + "measurement": ("trial", np.arange(0, 100, 10)), + "time": ("trial", pd.date_range("20240101T1500", "20240101T1501", 10)), + } + ) + + # Time variable should be preserved in the aggregation + ds_agged = ds.groupby_bins("trial", 5).mean() + assert "time" in ds_agged.data_vars + assert "measurement" in ds_agged.data_vars + assert ds_agged.time.dtype == np.dtype("datetime64[ns]") + + # TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 15319e2f6c8..e386b96f63d 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -215,7 +215,7 @@ def {method}( function for calculating ``{method}`` on this object's data. These could include dask-specific kwargs like ``split_every``.""" -_NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." +_NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing. datetime64 and timedelta64 dtypes are treated as numeric for aggregation operations." _FLOX_NOTES_TEMPLATE = """Use the ``flox`` package to significantly speed up {kind} computations, especially with dask arrays. Xarray will use flox by default if installed. 
From 145a445cca8087fa5ae5f269552b9edbb338bbc4 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Fri, 12 Sep 2025 23:20:58 +0200 Subject: [PATCH 26/37] TYP: explicit `DTypeLike | None` (#10738) * TYP: explicit `DTypeLike | None` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/backends/common.py | 6 +++--- xarray/coding/common.py | 4 ++-- xarray/coding/variables.py | 6 ++++-- xarray/core/accessor_dt.py | 2 +- xarray/core/accessor_str.py | 4 ++-- xarray/core/common.py | 6 +++--- xarray/core/datatree.py | 2 +- xarray/core/dtypes.py | 2 +- xarray/core/extension_array.py | 2 +- xarray/core/groupby.py | 2 +- xarray/core/indexing.py | 4 ++-- xarray/namedarray/dtypes.py | 2 +- xarray/tests/arrays.py | 6 +++--- xarray/tests/test_assertions.py | 6 +++++- xarray/tests/test_coding_times.py | 4 ++-- xarray/tests/test_formatting.py | 6 +++++- xarray/tests/test_namedarray.py | 4 ++-- xarray/tests/test_variable.py | 2 +- 18 files changed, 40 insertions(+), 30 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index a725fa27b70..7f6921ae2a1 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -309,11 +309,11 @@ class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): async def async_getitem(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: raise NotImplementedError("Backend does not support asynchronous loading") - def get_duck_array(self, dtype: np.typing.DTypeLike = None): + def get_duck_array(self, dtype: np.typing.DTypeLike | None = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return self[key] # type: ignore[index] - async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None): + async def async_get_duck_array(self, dtype: np.typing.DTypeLike | None = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) 
return await self.async_getitem(key) @@ -644,7 +644,7 @@ def _infer_dtype(array, name=None): ) -def _copy_with_dtype(data, dtype: np.typing.DTypeLike): +def _copy_with_dtype(data, dtype: np.typing.DTypeLike | None): """Create a copy of an array with the given dtype. We use this instead of np.array() to ensure that custom object dtypes end diff --git a/xarray/coding/common.py b/xarray/coding/common.py index 79e5e7502b3..a624c8fa57e 100644 --- a/xarray/coding/common.py +++ b/xarray/coding/common.py @@ -53,7 +53,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin): Values are computed upon indexing or coercion to a NumPy array. """ - def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike): + def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike | None): assert not is_chunked_array(array) self.array = indexing.as_indexable(array) self.func = func @@ -86,7 +86,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})" -def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike): +def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike | None): """Lazily apply an element-wise function to an array. 
Parameters ---------- diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 1a1bb7c03db..e5466d6cb6e 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -122,7 +122,7 @@ def _apply_mask( data: np.ndarray, encoded_fill_values: list, decoded_fill_value: Any, - dtype: np.typing.DTypeLike, + dtype: np.typing.DTypeLike | None, ) -> np.ndarray: """Mask all matching values in a NumPy arrays.""" data = np.asarray(data, dtype=dtype) @@ -426,7 +426,9 @@ def decode(self, variable: Variable, name: T_Name = None): return Variable(dims, data, attrs, encoding, fastpath=True) -def _scale_offset_decoding(data, scale_factor, add_offset, dtype: np.typing.DTypeLike): +def _scale_offset_decoding( + data, scale_factor, add_offset, dtype: np.typing.DTypeLike | None +): data = data.astype(dtype=dtype, copy=True) if scale_factor is not None: data *= scale_factor diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index f91bdd11f5d..e6e75fb3213 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -244,7 +244,7 @@ class TimeAccessor(Generic[T_DataArray]): def __init__(self, obj: T_DataArray) -> None: self._obj = obj - def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray: + def _date_field(self, name: str, dtype: DTypeLike | None) -> T_DataArray: if dtype is None: dtype = self._obj.dtype result = _get_date_field(_index_or_data(self._obj), name, dtype) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index f16dbe02f32..a632d2729c8 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -112,7 +112,7 @@ def _apply_str_ufunc( *, func: Callable, obj: Any, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, output_core_dims: list | tuple = ((),), output_sizes: Mapping[Any, int] | None = None, func_args: tuple = (), @@ -224,7 +224,7 @@ def _apply( self, *, func: Callable, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, output_core_dims: 
list | tuple = ((),), output_sizes: Mapping[Any, int] | None = None, func_args: tuple = (), diff --git a/xarray/core/common.py b/xarray/core/common.py index 8f789c35445..b4a2dc1104f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -164,7 +164,7 @@ def __complex__(self: Any) -> complex: return complex(self.values) def __array__( - self: Any, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self: Any, dtype: DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: if not copy: if np.lib.NumpyVersion(np.__version__) >= "2.0.0": @@ -2073,12 +2073,12 @@ def get_chunksizes( return Frozen(chunks) -def is_np_datetime_like(dtype: DTypeLike) -> bool: +def is_np_datetime_like(dtype: DTypeLike | None) -> bool: """Check if a dtype is a subclass of the numpy datetime types""" return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def is_np_timedelta_like(dtype: DTypeLike) -> bool: +def is_np_timedelta_like(dtype: DTypeLike | None) -> bool: """Check whether dtype is of the timedelta64 dtype.""" return np.issubdtype(dtype, np.timedelta64) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 5fe1362c3c6..99441b1a8d4 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -818,7 +818,7 @@ def __iter__(self) -> Iterator[str]: return itertools.chain(self._data_variables, self._children) # type: ignore[arg-type] def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: raise TypeError( "cannot directly convert a DataTree into a " diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 0a7b1722877..6ddae75e9cf 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -238,7 +238,7 @@ def preprocess_types(t): def result_type( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, + *arrays_and_dtypes: np.typing.ArrayLike | 
np.typing.DTypeLike | None, xp=None, ) -> np.dtype: """Like np.result_type, but with type promotion rules matching pandas. diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 9262982d4cb..8a809759a6f 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -160,7 +160,7 @@ def ndim(self) -> int: return 1 def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: if Version(np.__version__) >= Version("2.0.0"): return np.asarray(self.array, dtype=dtype, copy=copy) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 723a89ea3b7..57802c199e1 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -211,7 +211,7 @@ def data(self) -> np.ndarray: return np.arange(self.size, dtype=int) def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: if copy is False: raise NotImplementedError(f"An array copy is necessary, got {copy = }.") diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 94c5aa7ac49..a86b7f56606 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -570,7 +570,7 @@ class ExplicitlyIndexed: __slots__ = () def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: # Leave casting to an array up to the underlying array type. 
if Version(np.__version__) >= Version("2.0.0"): @@ -653,7 +653,7 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.indexer_cls = indexer_cls def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: if Version(np.__version__) >= Version("2.0.0"): return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py index a49f7686179..b4a3d6a518e 100644 --- a/xarray/namedarray/dtypes.py +++ b/xarray/namedarray/dtypes.py @@ -165,7 +165,7 @@ def is_datetime_like( def result_type( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | None, ) -> np.dtype[np.generic]: """Like np.result_type, but with type promotion rules matching pandas. diff --git a/xarray/tests/arrays.py b/xarray/tests/arrays.py index 4ee415619ab..d136d4fe752 100644 --- a/xarray/tests/arrays.py +++ b/xarray/tests/arrays.py @@ -25,7 +25,7 @@ def get_duck_array(self): raise UnexpectedDataAccess("Tried accessing data") def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: raise UnexpectedDataAccess("Tried accessing data") @@ -56,7 +56,7 @@ def to_numpy(self) -> np.ndarray: return self.array def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: raise UnexpectedDataAccess("Tried accessing data") @@ -169,7 +169,7 @@ def get_duck_array(self): raise UnexpectedDataAccess("Tried accessing data") def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = 
None ) -> np.ndarray: raise UnexpectedDataAccess("Tried accessing data") diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index a0a2c02d578..222a01a6628 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -192,7 +192,11 @@ def dims(self): return super().dims def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: warnings.warn("warning in test", stacklevel=2) return super().__array__(dtype, copy=copy) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 103e761ae05..730d6f1dfee 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1383,7 +1383,7 @@ def test_contains_cftime_lazy() -> None: def test_roundtrip_datetime64_nanosecond_precision( timestr: str, format: Literal["ns", "us"], - dtype: np.typing.DTypeLike, + dtype: np.typing.DTypeLike | None, fill_value: int | float | None, use_encoding: bool, time_unit: PDDatetimeUnitOptions, @@ -1499,7 +1499,7 @@ def test_roundtrip_datetime64_nanosecond_precision_warning( [(np.int64, 20), (np.int64, np.iinfo(np.int64).min), (np.float64, 1e30)], ) def test_roundtrip_timedelta64_nanosecond_precision( - dtype: np.typing.DTypeLike, + dtype: np.typing.DTypeLike | None, fill_value: int | float, time_unit: PDDatetimeUnitOptions, ) -> None: diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index c2ab1144e7b..7530bd859a0 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -963,7 +963,11 @@ def test_lazy_array_wont_compute() -> None: class LazilyIndexedArrayNotComputable(LazilyIndexedArray): def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: raise 
NotImplementedError("Computing this array is not possible.") diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 537cd824767..ce825e5bb81 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -55,7 +55,7 @@ class CustomArray( CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] ): def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, dtype: DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray[Any, np.dtype[np.generic]]: if Version(np.__version__) >= Version("2.0.0"): return np.asarray(self.array, dtype=dtype, copy=copy) @@ -292,7 +292,7 @@ def test_real_and_imag(self) -> None: (b"foo", np.dtype("S3")), ], ) - def test_from_array_0d_string(self, data: Any, dtype: DTypeLike) -> None: + def test_from_array_0d_string(self, data: Any, dtype: DTypeLike | None) -> None: named_array: NamedArray[Any, Any] named_array = from_array([], data) assert named_array.data == data diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 41475a0be6e..8b3d43fb3d6 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -337,7 +337,7 @@ def test_pandas_period_index(self): assert "PeriodArray" in repr(v) @pytest.mark.parametrize("dtype", [float, int]) - def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: + def test_1d_math(self, dtype: np.typing.DTypeLike | None) -> None: x = np.arange(5, dtype=dtype) y = np.ones(5, dtype=dtype) From bc11b60d1c3b24e29a07e19cc120e3c0da28475c Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:24:19 -0700 Subject: [PATCH 27/37] refine Claude instructions (#10744) --- CLAUDE.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index e55dc61f412..c8cfe1185ce 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -22,7 +22,9 @@ uv run dmypy run # Type 
checking with mypy ## GitHub Interaction Guidelines -- **NEVER impersonate the user on GitHub** - Do not post comments, create issues, or interact with the xarray GitHub repository unless explicitly instructed -- Never create GitHub issues or PRs unless explicitly requested by the user -- Never post "update" messages, progress reports, or explanatory comments on GitHub issues/PRs unless specifically asked -- Always require explicit user direction before creating pull requests or pushing to the xarray GitHub repository +- **NEVER impersonate the user on GitHub**, always sign off with something like + "[This is Claude Code on behalf of Jane Doe]" +- Never create issues nor pull requests on the xarray GitHub repository unless + explicitly instructed +- Never post "update" messages, progress reports, or explanatory comments on + GitHub issues/PRs unless specifically instructed From 917f2380c8a7c7801cad22d75d73e28fc21a6f80 Mon Sep 17 00:00:00 2001 From: gronniger <50588526+gronniger@users.noreply.github.com> Date: Mon, 15 Sep 2025 18:16:11 +0200 Subject: [PATCH 28/37] fix bugreport.yml (#10746) * fix bugreport.yml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bugreport.yml | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index dca031aca78..a25d2fb8379 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -29,16 +29,11 @@ body: Minimal, self-contained copy-pastable example that demonstrates the issue. This will be automatically formatted into code, so no need for markdown backticks. 
render: Python - - type: checkboxes - id: mvce-checkboxes + - type: textarea + id: reproduce attributes: - label: MVCE confirmation - description: | - Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. For more details, check out: - - - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) - - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) - + label: Steps to reproduce + description: Consider listing additional or specific dependencies in [inline script metadata](https://packaging.python.org/en/latest/specifications/inline-script-metadata/#example) so that calling `uv run issue.py` shows the issue when copied into `issue.py`. (not strictly required) value: | @@ -56,6 +51,19 @@ body: xr.show_versions() # your reproducer code ... ``` + validations: + required: false + + - type: checkboxes + id: mvce-checkboxes + attributes: + label: MVCE confirmation + description: | + Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. For more details, check out: + + - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) + - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) + options: - label: Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray. - label: Complete example — the example is self-contained, including all data and the text of any traceback. 
From d9047c235f2a4b95102094864f814b8f8e219327 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:39:40 -0700 Subject: [PATCH 29/37] Add co-authorship instruction to CLAUDE.md (#10748) Co-authored-by: Claude --- CLAUDE.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index c8cfe1185ce..1781aeed1cc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,3 +28,5 @@ uv run dmypy run # Type checking with mypy explicitly instructed - Never post "update" messages, progress reports, or explanatory comments on GitHub issues/PRs unless specifically instructed +- When creating commits, always include a co-authorship trailer: + `Co-authored-by: Claude ` From 2a020c0be683773d9c4ade141721f47a00667110 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:50:45 -0700 Subject: [PATCH 30/37] Revert "fix bugreport.yml (#10746)" (#10749) This reverts commit 2b52b956f7123f480bfd1bfe0112b9a24df38b54. --- .github/ISSUE_TEMPLATE/bugreport.yml | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index a25d2fb8379..dca031aca78 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -29,11 +29,16 @@ body: Minimal, self-contained copy-pastable example that demonstrates the issue. This will be automatically formatted into code, so no need for markdown backticks. render: Python - - type: textarea - id: reproduce + - type: checkboxes + id: mvce-checkboxes attributes: - label: Steps to reproduce - description: + label: MVCE confirmation + description: | + Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. 
For more details, check out: + + - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) + - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) + Consider listing additional or specific dependencies in [inline script metadata](https://packaging.python.org/en/latest/specifications/inline-script-metadata/#example) so that calling `uv run issue.py` shows the issue when copied into `issue.py`. (not strictly required) value: | @@ -51,19 +56,6 @@ body: xr.show_versions() # your reproducer code ... ``` - validations: - required: false - - - type: checkboxes - id: mvce-checkboxes - attributes: - label: MVCE confirmation - description: | - Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. For more details, check out: - - - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) - - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) - options: - label: Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray. - label: Complete example — the example is self-contained, including all data and the text of any traceback. From a5f33e80e9b7ad1a856a8238bfefb3dfc8577480 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 15 Sep 2025 10:02:19 -0700 Subject: [PATCH 31/37] Fix to_netcdf(compute=False) with Dask distributed (#10730) * Fix to_netcdf(compute=False) with Dask distributed Fixes #10725 * Silence incorrect mypy error --- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 27 ++++++++++++++++----------- xarray/tests/test_distributed.py | 8 +++++++- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5d4d9896180..60f03094c5c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -48,6 +48,9 @@ Bug fixes - Fix error when encoding an empty :py:class:`numpy.datetime64` array (:issue:`10722`, :pull:`10723`). 
By `Spencer Clark `_. +- Fix error from ``to_netcdf(..., compute=False)`` when using Dask Distributed + (:issue:`10725`). + By `Stephan Hoyer `_. - Propagation coordinate attrs in :py:meth:`xarray.Dataset.map` (:issue:`9317`, :pull:`10602`). By `Justus Magin `_. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index aa6fddd89e1..7d23837da35 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1849,6 +1849,20 @@ def open_mfdataset( return combined +def _get_netcdf_autoclose(dataset: Dataset, engine: T_NetcdfEngine) -> bool: + """Should we close files after each write operations?""" + scheduler = get_dask_scheduler() + have_chunks = any(v.chunks is not None for v in dataset.variables.values()) + + autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] + if autoclose and engine == "scipy": + raise NotImplementedError( + f"Writing netCDF files with the {engine} backend " + f"is not currently supported with dask's {scheduler} scheduler" + ) + return autoclose + + WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, @@ -2055,16 +2069,7 @@ def to_netcdf( # sanitize unlimited_dims unlimited_dims = _sanitize_unlimited_dims(dataset, unlimited_dims) - # handle scheduler specific logic - scheduler = get_dask_scheduler() - have_chunks = any(v.chunks is not None for v in dataset.variables.values()) - - autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] - if autoclose and engine == "scipy": - raise NotImplementedError( - f"Writing netCDF files with the {engine} backend " - f"is not currently supported with dask's {scheduler} scheduler" - ) + autoclose = _get_netcdf_autoclose(dataset, engine) if path_or_file is None: if not compute: @@ -2107,7 +2112,7 @@ def to_netcdf( writes = writer.sync(compute=compute) finally: - if not multifile: + if not multifile and not autoclose: # type: ignore[redundant-expr,unused-ignore] if compute: 
store.close() else: diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 9ae83bc2664..db17a2c13df 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -85,11 +85,13 @@ def tmp_netcdf_filename(tmpdir): @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) +@pytest.mark.parametrize("compute", [True, False]) def test_dask_distributed_netcdf_roundtrip( loop, # noqa: F811 tmp_netcdf_filename, engine, nc_format, + compute, ): if engine not in ENGINES: pytest.skip("engine not available") @@ -107,7 +109,11 @@ def test_dask_distributed_netcdf_roundtrip( ) return - original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) + result = original.to_netcdf( + tmp_netcdf_filename, engine=engine, format=nc_format, compute=compute + ) + if not compute: + result.compute() with xr.open_dataset( tmp_netcdf_filename, chunks=chunks, engine=engine From f86afc381d61863d71a30226dc7b72444423efd1 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Mon, 15 Sep 2025 13:26:58 -0400 Subject: [PATCH 32/37] cleanup bug report template (#10752) * cleanup bug report template * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bugreport.yml | 37 +++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index dca031aca78..86ab48f224f 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -26,23 +26,14 @@ body: attributes: label: Minimal Complete Verifiable Example description: | - Minimal, self-contained copy-pastable example that demonstrates the issue. This will be automatically formatted into code, so no need for markdown backticks. 
- render: Python - - - type: checkboxes - id: mvce-checkboxes - attributes: - label: MVCE confirmation - description: | - Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. For more details, check out: - - - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) - - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) + Minimal, self-contained copy-pastable example that demonstrates the issue. Consider listing additional or specific dependencies in [inline script metadata](https://packaging.python.org/en/latest/specifications/inline-script-metadata/#example) so that calling `uv run issue.py` shows the issue when copied into `issue.py`. (not strictly required) + + This will be automatically formatted into code, so no need for markdown backticks. + render: Python value: | - ```python # /// script # requires-python = ">=3.11" # dependencies = [ @@ -55,7 +46,25 @@ body: import xarray as xr xr.show_versions() # your reproducer code ... - ``` + + - type: textarea + id: reproduce + attributes: + label: Steps to reproduce + description: + validations: + required: false + + - type: checkboxes + id: mvce-checkboxes + attributes: + label: MVCE confirmation + description: | + Please confirm that the bug report is in an excellent state, so we can understand & fix it quickly & efficiently. For more details, check out: + + - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) + - [Craft Minimal Bug Reports](https://matthewrocklin.com/minimal-bug-reports) + options: - label: Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray. - label: Complete example — the example is self-contained, including all data and the text of any traceback. 
From da93f8cfc59870515a119011a2368acc8bf86041 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Mon, 15 Sep 2025 19:30:11 +0200 Subject: [PATCH 33/37] Ensure that groupby.groups works even with multiple groupers (#10750) --- xarray/core/groupby.py | 2 +- xarray/tests/test_groupby.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 57802c199e1..a5482c5a514 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -828,7 +828,7 @@ def groups(self) -> dict[GroupKey, GroupIndex]: self._groups = dict( zip( self.encoded.unique_coord.data, - self.encoded.group_indices, + tuple(g for g in self.encoded.group_indices if g), strict=True, ) ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 202b729abcf..336a5e6c91c 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3200,6 +3200,16 @@ def test_multiple_grouper_unsorted_order() -> None: assert_identical(actual2, expected2) +def test_multiple_grouper_empty_groups() -> None: + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 3))}, + coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ) + + groups = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()) + assert len(groups.groups) == 2 + + def test_groupby_multiple_bin_grouper_missing_groups() -> None: from numpy import nan From 74a075b17bd2291c3375478030fb1925bf3f34e8 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 15 Sep 2025 11:03:52 -0700 Subject: [PATCH 34/37] Remove xarray.tests.test_combine from mypy exclusions (#10753) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add type annotations to satisfy mypy's check_untyped_defs requirement - Fix combine_nested type signature to accept None for concat_dim parameter - Add minimal type hints where mypy cannot infer types - No behavioral changes, 
all tests pass 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude --- pyproject.toml | 1 - xarray/structure/combine.py | 2 +- xarray/tests/test_combine.py | 56 ++++++++++++++++++++++++------------ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a71ac7de00..0859b66df55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,7 +190,6 @@ check_untyped_defs = false module = [ "xarray.tests.test_coarsen", "xarray.tests.test_coding_times", - "xarray.tests.test_combine", "xarray.tests.test_computation", "xarray.tests.test_concat", "xarray.tests.test_coordinates", diff --git a/xarray/structure/combine.py b/xarray/structure/combine.py index b1dbc8e4c0e..9a0aadbf730 100644 --- a/xarray/structure/combine.py +++ b/xarray/structure/combine.py @@ -405,7 +405,7 @@ def _nested_combine( def combine_nested( datasets: DATASET_HYPERCUBE, - concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None], + concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, compat: str | CombineKwargDefault = _COMPAT_DEFAULT, data_vars: str | CombineKwargDefault = _DATA_VARS_DEFAULT, coords: str | CombineKwargDefault = _COORDS_DEFAULT, diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 3b4ad097795..c7c2a60010f 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -28,7 +28,7 @@ from xarray.tests.test_dataset import create_test_data -def assert_combined_tile_ids_equal(dict1, dict2): +def assert_combined_tile_ids_equal(dict1: dict, dict2: dict) -> None: assert len(dict1) == len(dict2) for k in dict1.keys(): assert k in dict2.keys() @@ -41,7 +41,9 @@ def test_1d(self): input = [ds(0), ds(1)] expected = {(0,): ds(0), (1,): ds(1)} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def 
test_2d(self): @@ -56,7 +58,9 @@ def test_2d(self): (2, 0): ds(4), (2, 1): ds(5), } - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_3d(self): @@ -80,7 +84,9 @@ def test_3d(self): (1, 2, 0): ds(10), (1, 2, 1): ds(11), } - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_single_dataset(self): @@ -88,7 +94,9 @@ def test_single_dataset(self): input = [ds] expected = {(0,): ds} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_redundant_nesting(self): @@ -96,24 +104,30 @@ def test_redundant_nesting(self): input = [[ds(0)], [ds(1)]] expected = {(0, 0): ds(0), (1, 0): ds(1)} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_ignore_empty_list(self): ds = create_test_data(0) - input = [ds, []] + input: list = [ds, []] expected = {(0,): ds} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_uneven_depth_input(self): # Auto_combine won't work on ragged input # but this is just to increase test coverage ds = create_test_data - input = [ds(0), [ds(1), ds(2)]] + input: list = [ds(0), [ds(1), ds(2)]] expected = {(0,): ds(0), (1, 0): ds(1), (1, 1): ds(2)} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) 
assert_combined_tile_ids_equal(expected, actual) def test_uneven_length_input(self): @@ -123,7 +137,9 @@ def test_uneven_length_input(self): input = [[ds(0)], [ds(1), ds(2)]] expected = {(0, 0): ds(0), (1, 0): ds(1), (1, 1): ds(2)} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) def test_infer_from_datasets(self): @@ -131,7 +147,9 @@ def test_infer_from_datasets(self): input = [ds(0), ds(1)] expected = {(0,): ds(0), (1,): ds(1)} - actual = _infer_concat_order_from_positions(input) + actual: dict[tuple[int, ...], Dataset] = _infer_concat_order_from_positions( + input + ) assert_combined_tile_ids_equal(expected, actual) @@ -581,8 +599,8 @@ def test_auto_combine_2d_combine_attrs_kwarg(self): expected_dict["override"] = expected.copy(deep=True) expected_dict["override"].attrs = {"a": 1} f = lambda attrs, context: attrs[0] - expected_dict[f] = expected.copy(deep=True) - expected_dict[f].attrs = f([{"a": 1}], None) + expected_dict[f] = expected.copy(deep=True) # type: ignore[index] + expected_dict[f].attrs = f([{"a": 1}], None) # type: ignore[index] datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] @@ -606,7 +624,7 @@ def test_auto_combine_2d_combine_attrs_kwarg(self): datasets, concat_dim=["dim1", "dim2"], data_vars="all", - combine_attrs=combine_attrs, + combine_attrs=combine_attrs, # type: ignore[arg-type] ) assert_identical(result, expected) @@ -632,11 +650,11 @@ def test_invalid_hypercube_input(self): ): combine_nested(datasets, concat_dim=["dim1", "dim2"]) - datasets = [[ds(0), ds(1)], [[ds(3), ds(4)]]] + datasets2: list = [[ds(0), ds(1)], [[ds(3), ds(4)]]] with pytest.raises( ValueError, match=r"sub-lists do not have consistent depths" ): - combine_nested(datasets, concat_dim=["dim1", "dim2"]) + combine_nested(datasets2, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [ds(3), ds(4)]] with 
pytest.raises(ValueError, match=r"concat_dims has length"): @@ -1019,7 +1037,7 @@ def test_infer_order_from_coords(self): objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] actual = combine_by_coords(objs, data_vars="all") expected = data - assert expected.broadcast_equals(actual) + assert expected.broadcast_equals(actual) # type: ignore[arg-type] with set_options(use_new_combine_kwarg_defaults=True): actual = combine_by_coords(objs) @@ -1067,7 +1085,7 @@ def test_combine_by_coords_still_fails(self): # https://github.com/pydata/xarray/issues/508 datasets = [Dataset({"x": 0}, {"y": 0}), Dataset({"x": 1}, {"y": 1, "z": 1})] with pytest.raises(ValueError): - combine_by_coords(datasets, "y") + combine_by_coords(datasets, "y") # type: ignore[arg-type] def test_combine_by_coords_no_concat(self): objs = [Dataset({"x": 0}), Dataset({"y": 1})] From c3653cebbdc100ffc636259965e37c73f9470e6d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 15 Sep 2025 12:05:37 -0700 Subject: [PATCH 35/37] Synchronize mypy test requirements to 1.17.1 (#10751) * Update mypy test requirement to 1.17.1 * Fix mypy 1.17.1 compatibility issues - Add list-item to type ignore for df.columns assignment - Remove unnecessary attr-defined type ignores that mypy 1.17.1 now understands * Restore type ignores needed by mypy 1.17.1 in CI environment --------- Co-authored-by: Maximilian Roos Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 4 ++-- xarray/tests/test_dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index d744fcafc6d..2cd28506a05 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -116,7 +116,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy==1.15" --force-reinstall + python -m pip install "mypy==1.17.1" 
--force-reinstall - name: Run mypy run: | @@ -167,7 +167,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy==1.15" --force-reinstall + python -m pip install "mypy==1.17.1" --force-reinstall - name: Run mypy run: | diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 959177dec68..355dae5cbee 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5303,7 +5303,7 @@ def test_from_dataframe_unsorted_levels(self) -> None: def test_from_dataframe_non_unique_columns(self) -> None: # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) - df.columns = ["foo", "foo"] # type: ignore[assignment,unused-ignore] + df.columns = ["foo", "foo"] # type: ignore[assignment,list-item,unused-ignore] with pytest.raises(ValueError, match=r"non-unique columns"): Dataset.from_dataframe(df) From 2d56f6b6916a90469afb5c4bbbc9a28a062ecc1f Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Tue, 16 Sep 2025 00:08:48 +0200 Subject: [PATCH 36/37] Multiple imports for an import name (#10743) --- xarray/backends/api.py | 1 - xarray/computation/apply_ufunc.py | 2 +- xarray/core/dataset.py | 2 -- xarray/core/groupby.py | 1 - xarray/groupers.py | 7 ------- xarray/structure/alignment.py | 1 - xarray/tests/test_backends.py | 7 +------ 7 files changed, 2 insertions(+), 19 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7d23837da35..bba0d06fe96 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -77,7 +77,6 @@ ErrorOptionsWithWarn, JoinOptions, NestedSequence, - ReadBuffer, T_Chunks, ZarrStoreLike, ) diff --git a/xarray/computation/apply_ufunc.py b/xarray/computation/apply_ufunc.py index 00a06e12d63..e205059841d 100644 --- a/xarray/computation/apply_ufunc.py +++ b/xarray/computation/apply_ufunc.py @@ -715,7 +715,7 @@ def apply_variable_ufunc( ) -> Variable | tuple[Variable, ...]: 
"""Apply a ndarray level function over Variable and/or ndarray objects.""" from xarray.core.formatting import short_array_repr - from xarray.core.variable import Variable, as_compatible_data + from xarray.core.variable import as_compatible_data dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 377f6db28f3..2310d391e67 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8465,8 +8465,6 @@ def integrate( return result def _integrate_one(self, coord, datetime_unit=None, cumulative=False): - from xarray.core.variable import Variable - if coord not in self.variables and coord not in self.dims: variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) raise ValueError( diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a5482c5a514..a986bd9d937 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -390,7 +390,6 @@ def _parse_group_and_groupers( eagerly_compute_group: Literal[False] | None, ) -> tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray - from xarray.core.variable import Variable from xarray.groupers import Grouper, UniqueGrouper if group is not None and groupers: diff --git a/xarray/groupers.py b/xarray/groupers.py index 0f2ec8ac4c1..a16933e690f 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -547,9 +547,6 @@ def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: return full_index, first_items, codes def first_items(self) -> tuple[pd.Series, np.ndarray]: - from xarray.coding.cftimeindex import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper - if isinstance(self.index_grouper, CFTimeGrouper): return self.index_grouper.first_items( cast(CFTimeIndex, self.group_as_index) @@ -605,8 +602,6 @@ def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ... tuple[int, ...] 
A tuple of chunk sizes for the dimension. """ - from xarray.core.dataarray import DataArray - if not _contains_datetime_like_objects(variable): raise ValueError( f"Computing chunks with {type(self)!r} only supported for datetime variables. " @@ -1050,8 +1045,6 @@ def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ... tuple[int, ...] A tuple of chunk sizes for the dimension. """ - from xarray.core.dataarray import DataArray - if not _contains_datetime_like_objects(variable): raise ValueError( f"Computing chunks with {type(self)!r} only supported for datetime variables. " diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index a9c0832de36..f5d61e63409 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -30,7 +30,6 @@ from xarray.core.dataset import Dataset from xarray.core.types import ( Alignable, - JoinOptions, T_DataArray, T_Dataset, T_DuckArray, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1b687086659..715314daddd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -193,7 +193,6 @@ def _check_compression_codec_available(codec: str | None) -> bool: try: import os - import tempfile import netCDF4 @@ -5361,11 +5360,9 @@ def test_open_mfdataset_list_attr() -> None: """ Case when an attribute of type list differs across the multiple files """ - from netCDF4 import Dataset - with create_tmp_files(2) as nfiles: for i in range(2): - with Dataset(nfiles[i], "w") as f: + with nc4.Dataset(nfiles[i], "w") as f: f.createDimension("x", 3) vlvar = f.createVariable("test_var", np.int32, ("x")) # here create an attribute as a list @@ -7332,8 +7329,6 @@ def test_zarr_closing_internal_zip_store(): @requires_zarr @pytest.mark.parametrize("create_default_indexes", [True, False]) def test_zarr_create_default_indexes(tmp_path, create_default_indexes) -> None: - from xarray.core.indexes import PandasIndex - store_path = tmp_path / "tmp.zarr" 
original_ds = xr.Dataset({"data": ("x", np.arange(3))}, coords={"x": [-1, 0, 1]}) original_ds.to_zarr(store_path, mode="w") From 8f9363e3eb3f2216b7fb06aee724d0fb83f7ed57 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 15 Sep 2025 18:04:21 -0700 Subject: [PATCH 37/37] Add note about breaking change --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 60f03094c5c..a0c4c2bfb4f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,11 @@ Breaking changes dataset in-place. (:issue:`10167`) By `Maximilian Roos `_. +- The default ``engine`` when reading/writing netCDF files in-memory is now + netCDF4, consistent with Xarray's default ``engine`` when reading/writing netCDF + files to disk (:pull:`10624`). + By `Stephan Hoyer `_. + Deprecations ~~~~~~~~~~~~