diff --git a/doc/whats-new.rst b/doc/whats-new.rst index acb81f3692a..a0c4c2bfb4f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,10 @@ v2025.09.1 (unreleased) New Features ~~~~~~~~~~~~ +- ``engine='netcdf4'`` now supports reading and writing in-memory netCDF files. + All of Xarray's netCDF backends now support in-memory reads and writes + (:pull:`10624`). + By `Stephan Hoyer `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -22,6 +26,11 @@ Breaking changes dataset in-place. (:issue:`10167`) By `Maximilian Roos `_. +- The default ``engine`` when reading/writing netCDF files in-memory is now + netCDF4, consistent with Xarray's default ``engine`` when read/writing netCDF + files to disk (:pull:`10624`). + By `Stephan Hoyer `_. + Deprecations ~~~~~~~~~~~~ @@ -29,6 +38,16 @@ Deprecations Bug fixes ~~~~~~~~~ +- Xarray objects opened from file-like objects with ``engine='h5netcdf'`` can + now be pickled, as long as the underlying file-like object also supports + pickle (:issue:`10712`). + By `Stephan Hoyer `_. + +- Closing Xarray objects opened from file-like objects with ```engine='scipy'`` + no longer closes the underlying file, consistent with the h5netcdf backend + (:pull:`10624`). + By `Stephan Hoyer `_. + - Fix the ``align_chunks`` parameter on the :py:meth:`~xarray.Dataset.to_zarr` method, it was not being passed to the underlying :py:meth:`~xarray.backends.api` method (:issue:`10501`, :pull:`10516`). - Fix error when encoding an empty :py:class:`numpy.datetime64` array diff --git a/xarray/backends/api.py b/xarray/backends/api.py index e074225de86..bba0d06fe96 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -99,7 +99,7 @@ def get_default_netcdf_write_engine( format: T_NetcdfTypes | None, - to_fileobject_or_memoryview: bool, + to_fileobject: bool, ) -> Literal["netcdf4", "h5netcdf", "scipy"]: """Return the default netCDF library to use for writing a netCDF file.""" module_names = { @@ -118,7 +118,7 @@ def get_default_netcdf_write_engine( else: raise ValueError(f"unexpected {format=}") - if to_fileobject_or_memoryview: + if to_fileobject: candidates.remove("netcdf4") for engine in candidates: @@ -545,14 +545,12 @@ def open_dataset( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -788,10 +786,9 @@ def open_dataarray( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, decode_timedelta: bool | CFTimedeltaCoder | None = None, use_cftime: bool | None = None, concat_characters: bool | None = None, @@ -1018,14 +1015,12 @@ def open_datatree( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -1260,14 +1255,12 @@ def open_groups( cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | Mapping[str, bool] | None = None, - decode_times: bool - | CFDatetimeCoder - | Mapping[str, bool | CFDatetimeCoder] - | None = None, - decode_timedelta: bool - | CFTimedeltaCoder - | Mapping[str, bool | CFTimedeltaCoder] - | None = None, + decode_times: ( + bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder] | None + ) = None, + decode_timedelta: ( + bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder] | None + ) = None, use_cftime: bool | Mapping[str, bool] | None = None, concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, @@ -1523,10 +1516,9 @@ def _remove_path( def open_mfdataset( - paths: str - | os.PathLike - | ReadBuffer - | NestedSequence[str | os.PathLike | ReadBuffer], + paths: ( + str | os.PathLike | ReadBuffer | NestedSequence[str | os.PathLike | ReadBuffer] + ), chunks: T_Chunks = None, concat_dim: ( str @@ -1540,10 +1532,9 @@ def open_mfdataset( compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, preprocess: Callable[[Dataset], Dataset] | None = None, engine: T_Engine = None, - data_vars: Literal["all", "minimal", "different"] - | None - | list[str] - | CombineKwargDefault = _DATA_VARS_DEFAULT, + data_vars: ( + Literal["all", "minimal", "different"] | None | list[str] | CombineKwargDefault + ) = _DATA_VARS_DEFAULT, coords=_COORDS_DEFAULT, combine: Literal["by_coords", "nested"] = "by_coords", parallel: bool = False, @@ -2068,8 +2059,8 @@ def to_netcdf( path_or_file = _normalize_path(path_or_file) if engine is None: - to_fileobject_or_memoryview = not isinstance(path_or_file, str) - engine = get_default_netcdf_write_engine(format, to_fileobject_or_memoryview) + to_fileobject = isinstance(path_or_file, IOBase) + engine = get_default_netcdf_write_engine(format, to_fileobject) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 2a6f3691faf..f7cd4675729 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,31 +1,33 @@ from __future__ import annotations import atexit -import contextlib -import io import threading import uuid import warnings -from collections.abc import Hashable -from typing import Any +from collections.abc import Callable, Hashable, Iterator, Mapping, MutableMapping +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Generic, Literal, TypeVar, cast from xarray.backends.locks import acquire from xarray.backends.lru_cache import LRUCache from xarray.core import utils from xarray.core.options import OPTIONS +from xarray.core.types import Closable, Lock # Global cache for storing open files. -FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache( +FILE_CACHE: LRUCache[Any, Closable] = LRUCache( maxsize=OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close() ) assert FILE_CACHE.maxsize, "file cache must be at least size one" +T_File = TypeVar("T_File", bound=Closable) + REF_COUNTS: dict[Any, int] = {} -_DEFAULT_MODE = utils.ReprObject("") +_OMIT_MODE = utils.ReprObject("") -class FileManager: +class FileManager(Generic[T_File]): """Manager for acquiring and closing a file object. Use FileManager subclasses (CachingFileManager in particular) on backend @@ -33,11 +35,13 @@ class FileManager: many open files and transferring them between multiple processes. """ - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_File: """Acquire the file object from this manager.""" raise NotImplementedError() - def acquire_context(self, needs_lock=True): + def acquire_context( + self, needs_lock: bool = True + ) -> AbstractContextManager[T_File]: """Context manager for acquiring a file. Yields a file object. The context manager unwinds any actions taken as part of acquisition @@ -46,12 +50,12 @@ def acquire_context(self, needs_lock=True): """ raise NotImplementedError() - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: """Close the file object associated with this manager, if needed.""" raise NotImplementedError() -class CachingFileManager(FileManager): +class CachingFileManager(FileManager[T_File]): """Wrapper for automatically opening and closing file objects. Unlike files, CachingFileManager objects can be safely pickled and passed @@ -81,14 +85,14 @@ class CachingFileManager(FileManager): def __init__( self, - opener, - *args, - mode=_DEFAULT_MODE, - kwargs=None, - lock=None, - cache=None, + opener: Callable[..., T_File], + *args: Any, + mode: Any = _OMIT_MODE, + kwargs: Mapping[str, Any] | None = None, + lock: Lock | None | Literal[False] = None, + cache: MutableMapping[Any, T_File] | None = None, manager_id: Hashable | None = None, - ref_counts=None, + ref_counts: dict[Any, int] | None = None, ): """Initialize a CachingFileManager. @@ -134,13 +138,17 @@ def __init__( self._mode = mode self._kwargs = {} if kwargs is None else dict(kwargs) - self._use_default_lock = lock is None or lock is False - self._lock = threading.Lock() if self._use_default_lock else lock + if lock is None or lock is False: + self._use_default_lock = True + self._lock: Lock = threading.Lock() + else: + self._use_default_lock = False + self._lock = lock # cache[self._key] stores the file associated with this object. if cache is None: - cache = FILE_CACHE - self._cache = cache + cache = cast(MutableMapping[Any, T_File], FILE_CACHE) + self._cache: MutableMapping[Any, T_File] = cache if manager_id is None: # Each call to CachingFileManager should separately open files. manager_id = str(uuid.uuid4()) @@ -155,7 +163,7 @@ def __init__( self._ref_counter = _RefCounter(ref_counts) self._ref_counter.increment(self._key) - def _make_key(self): + def _make_key(self) -> _HashedSequence: """Make a key for caching files in the LRU cache.""" value = ( self._opener, @@ -166,8 +174,8 @@ def _make_key(self): ) return _HashedSequence(value) - @contextlib.contextmanager - def _optional_lock(self, needs_lock): + @contextmanager + def _optional_lock(self, needs_lock: bool): """Context manager for optionally acquiring a lock.""" if needs_lock: with self._lock: @@ -175,7 +183,7 @@ def _optional_lock(self, needs_lock): else: yield - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_File: """Acquire a file object from the manager. A new file is only opened if it has expired from the @@ -193,8 +201,8 @@ def acquire(self, needs_lock=True): file, _ = self._acquire_with_cache_info(needs_lock) return file - @contextlib.contextmanager - def acquire_context(self, needs_lock=True): + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: """Context manager for acquiring a file.""" file, cached = self._acquire_with_cache_info(needs_lock) try: @@ -204,14 +212,14 @@ def acquire_context(self, needs_lock=True): self.close(needs_lock) raise - def _acquire_with_cache_info(self, needs_lock=True): + def _acquire_with_cache_info(self, needs_lock: bool = True) -> tuple[T_File, bool]: """Acquire a file, returning the file and whether it was cached.""" with self._optional_lock(needs_lock): try: file = self._cache[self._key] except KeyError: kwargs = self._kwargs - if self._mode is not _DEFAULT_MODE: + if self._mode is not _OMIT_MODE: kwargs = kwargs.copy() kwargs["mode"] = self._mode file = self._opener(*self._args, **kwargs) @@ -223,7 +231,7 @@ def _acquire_with_cache_info(self, needs_lock=True): else: return file, True - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: """Explicitly close any associated file object (if necessary).""" # TODO: remove needs_lock if/when we have a reentrant lock in # dask.distributed: https://github.com/dask/dask/issues/3832 @@ -282,7 +290,7 @@ def __setstate__(self, state) -> None: def __repr__(self) -> str: args_string = ", ".join(map(repr, self._args)) - if self._mode is not _DEFAULT_MODE: + if self._mode is not _OMIT_MODE: args_string += f", mode={self._mode!r}" return ( f"{type(self).__name__}({self._opener!r}, {args_string}, " @@ -290,13 +298,6 @@ def __repr__(self) -> str: ) -@atexit.register -def _remove_del_method(): - # We don't need to close unclosed files at program exit, and may not be able - # to, because Python is cleaning up imports / globals. - del CachingFileManager.__del__ - - class _RefCounter: """Class for keeping track of reference counts.""" @@ -332,28 +333,136 @@ def __init__(self, tuple_value): self[:] = tuple_value self.hashvalue = hash(tuple_value) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return self.hashvalue -class DummyFileManager(FileManager): +def _get_none() -> None: + return None + + +class PickleableFileManager(FileManager[T_File]): + """File manager that supports pickling by reopening a file object. + + Use PickleableFileManager for wrapping file-like objects that do not natively + support pickling (e.g., netCDF4.Dataset and h5netcdf.File) in cases where a + global cache is not desirable (e.g., for netCDF files opened from bytes in + memory, or from existing file objects). + """ + + def __init__( + self, + opener: Callable[..., T_File], + *args: Any, + mode: Any = _OMIT_MODE, + kwargs: Mapping[str, Any] | None = None, + ): + kwargs = {} if kwargs is None else dict(kwargs) + self._opener = opener + self._args = args + self._mode = "a" if mode == "w" else mode + self._kwargs = kwargs + + # Note: No need for locking with PickleableFileManager, because all + # opening of files happens in the constructor. + if mode != _OMIT_MODE: + kwargs = kwargs | {"mode": mode} + self._file: T_File | None = opener(*args, **kwargs) + + @property + def _closed(self) -> bool: + # If opener() raised an error in the constructor, _file may not be set + return getattr(self, "_file", None) is None + + def _get_unclosed_file(self) -> T_File: + if self._closed: + raise RuntimeError("file is closed") + file = self._file + assert file is not None + return file + + def acquire(self, needs_lock: bool = True) -> T_File: + del needs_lock # unused + return self._get_unclosed_file() + + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: + del needs_lock # unused + yield self._get_unclosed_file() + + def close(self, needs_lock: bool = True) -> None: + del needs_lock # unused + if not self._closed: + file = self._get_unclosed_file() + file.close() + self._file = None + # Remove all references to opener arguments, so they can be garbage + # collected. + self._args = () + self._mode = _OMIT_MODE + self._kwargs = {} + + def __del__(self) -> None: + if not self._closed: + self.close() + + if OPTIONS["warn_for_unclosed_files"]: + warnings.warn( + f"deallocating {self}, but file is not already closed. " + "This may indicate a bug.", + RuntimeWarning, + stacklevel=2, + ) + + def __getstate__(self): + # file is intentionally omitted: we want to open it again + opener = _get_none if self._closed else self._opener + return (opener, self._args, self._mode, self._kwargs) + + def __setstate__(self, state) -> None: + opener, args, mode, kwargs = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs) # type: ignore[misc] + + def __repr__(self) -> str: + if self._closed: + return f"" + args_string = ", ".join(map(repr, self._args)) + if self._mode is not _OMIT_MODE: + args_string += f", mode={self._mode!r}" + kwargs = ( + self._kwargs | {"memory": utils.ReprObject("...")} + if "memory" in self._kwargs + else self._kwargs + ) + return f"{type(self).__name__}({self._opener!r}, {args_string}, {kwargs=})" + + +@atexit.register +def _remove_del_methods(): + # We don't need to close unclosed files at program exit, and may not be able + # to, because Python is cleaning up imports / globals. + del CachingFileManager.__del__ + del PickleableFileManager.__del__ + + +class DummyFileManager(FileManager[T_File]): """FileManager that simply wraps an open file in the FileManager interface.""" - def __init__(self, value, *, close=None): + def __init__(self, value: T_File, *, close: Callable[[], None] | None = None): if close is None: close = value.close self._value = value self._close = close - def acquire(self, needs_lock=True): - del needs_lock # ignored + def acquire(self, needs_lock: bool = True) -> T_File: + del needs_lock # unused return self._value - @contextlib.contextmanager - def acquire_context(self, needs_lock=True): - del needs_lock + @contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_File]: + del needs_lock # unused yield self._value - def close(self, needs_lock=True): - del needs_lock # ignored + def close(self, needs_lock: bool = True) -> None: + del needs_lock # unused self._close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 01275d8db5b..422eadc6c34 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -24,6 +24,7 @@ CachingFileManager, DummyFileManager, FileManager, + PickleableFileManager, ) from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from xarray.backends.netCDF4_ import ( @@ -216,11 +217,12 @@ def open( else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) - manager = ( - CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) - if isinstance(filename, str) - else h5netcdf.File(filename, mode=mode, **kwargs) + manager_cls = ( + CachingFileManager + if isinstance(filename, str) and not is_remote_uri(filename) + else PickleableFileManager ) + manager = manager_cls(h5netcdf.File, filename, mode=mode, kwargs=kwargs) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): @@ -647,7 +649,7 @@ def open_groups_as_dict( # only warn if phony_dims exist in file # remove together with the above check # after some versions - if store.ds._phony_dim_count > 0 and emit_phony_dims_warning: + if store.ds._root._phony_dim_count > 0 and emit_phony_dims_warning: _emit_phony_dims_warning() return groups_dict diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 82d3e0b7dae..784443544ee 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -4,15 +4,17 @@ import threading import uuid import weakref -from collections.abc import Hashable, MutableMapping -from typing import Any, ClassVar +from collections.abc import Callable, Hashable, MutableMapping, Sequence +from typing import Any, ClassVar, Literal from weakref import WeakValueDictionary +from xarray.core.types import Lock + # SerializableLock is adapted from Dask: # https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224 # Used under the terms of Dask's license, see licenses/DASK_LICENSE. -class SerializableLock: +class SerializableLock(Lock): """A Serializable per-process Lock This wraps a normal ``threading.Lock`` object and satisfies the same @@ -90,7 +92,7 @@ def __str__(self): _FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary() -def _get_threaded_lock(key): +def _get_threaded_lock(key: str) -> threading.Lock: try: lock = _FILE_LOCKS[key] except KeyError: @@ -98,14 +100,14 @@ def _get_threaded_lock(key): return lock -def _get_multiprocessing_lock(key): +def _get_multiprocessing_lock(key: str) -> Lock: # TODO: make use of the key -- maybe use locket.py? # https://github.com/mwilliamson/locket.py del key # unused return multiprocessing.Lock() -def _get_lock_maker(scheduler=None): +def _get_lock_maker(scheduler: str | None = None) -> Callable[..., Lock]: """Returns an appropriate function for creating resource locks. Parameters @@ -125,10 +127,8 @@ def _get_lock_maker(scheduler=None): elif scheduler == "distributed": # Lazy import distributed since it is can add a significant # amount of time to import - try: - from dask.distributed import Lock as DistributedLock - except ImportError: - DistributedLock = None + from dask.distributed import Lock as DistributedLock + return DistributedLock else: raise KeyError(scheduler) @@ -172,7 +172,7 @@ def get_dask_scheduler(get=None, collection=None) -> str | None: return "threaded" -def get_write_lock(key): +def get_write_lock(key: str) -> Lock: """Get a scheduler appropriate lock for writing to the given resource. Parameters @@ -207,14 +207,14 @@ def acquire(lock, blocking=True): return lock.acquire(blocking) -class CombinedLock: +class CombinedLock(Lock): """A combination of multiple locks. Like a locked door, a CombinedLock is locked if any of its constituent locks are locked. """ - def __init__(self, locks): + def __init__(self, locks: Sequence[Lock]): self.locks = tuple(set(locks)) # remove duplicates def acquire(self, blocking=True): @@ -239,7 +239,7 @@ def __repr__(self): return f"CombinedLock({list(self.locks)!r})" -class DummyLock: +class DummyLock(Lock): """DummyLock provides the lock API without any actual locking.""" def acquire(self, blocking=True): @@ -258,9 +258,9 @@ def locked(self): return False -def combine_locks(locks): +def combine_locks(locks: Sequence[Lock]) -> Lock: """Combine a sequence of locks into a single lock.""" - all_locks = [] + all_locks: list[Lock] = [] for lock in locks: if isinstance(lock, CombinedLock): all_locks.extend(lock.locks) @@ -276,7 +276,7 @@ def combine_locks(locks): return DummyLock() -def ensure_lock(lock): +def ensure_lock(lock: Lock | None | Literal[False]) -> Lock: """Ensure that the given object is a lock.""" if lock is None or lock is False: return DummyLock() diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index d6a37b06d88..234768ef891 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,6 +5,8 @@ import os from collections.abc import Iterable from contextlib import suppress +from dataclasses import dataclass +from io import IOBase from typing import TYPE_CHECKING, Any, Self import numpy as np @@ -13,6 +15,7 @@ BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, + BytesIOProxy, T_PathFileOrDataStore, WritableCFDataStore, _normalize_path, @@ -21,7 +24,11 @@ find_root_and_group, robust_getitem, ) -from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.file_manager import ( + CachingFileManager, + DummyFileManager, + PickleableFileManager, +) from xarray.backends.locks import ( HDF5_LOCK, NETCDFC_LOCK, @@ -48,6 +55,7 @@ from xarray.core.variable import Variable if TYPE_CHECKING: + import netCDF4 from h5netcdf.core import EnumType as h5EnumType from netCDF4 import EnumType as ncEnumType @@ -358,6 +366,28 @@ def _build_and_get_enum( return datatype +@dataclass +class _Thunk: + """Pickleable equivalent of `lambda: value`.""" + + value: Any + + def __call__(self): + return self.value + + +@dataclass +class _CloseWithCopy: + """Wrapper around netCDF4's esoteric interface for writing in-memory data.""" + + proxy: BytesIOProxy + nc4_dataset: netCDF4.Dataset + + def __call__(self): + value = self.nc4_dataset.close() + self.proxy.getvalue = _Thunk(value) + + class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. @@ -432,27 +462,31 @@ def open( if isinstance(filename, os.PathLike): filename = os.fspath(filename) - if not isinstance(filename, str): - raise ValueError( - "can only read bytes or file-like objects " - "with engine='scipy' or 'h5netcdf'" + if isinstance(filename, IOBase): + raise TypeError( + f"file objects are not supported by the netCDF4 backend: {filename}" ) + if not isinstance(filename, str | bytes | memoryview | BytesIOProxy): + raise TypeError(f"invalid filename for netCDF4 backend: {filename}") + if format is None: format = "NETCDF4" if lock is None: if mode == "r": - if is_remote_uri(filename): + if isinstance(filename, str) and is_remote_uri(filename): lock = NETCDFC_LOCK else: lock = NETCDF4_PYTHON_LOCK else: if format is None or format.startswith("NETCDF4"): - base_lock = NETCDF4_PYTHON_LOCK + lock = NETCDF4_PYTHON_LOCK else: - base_lock = NETCDFC_LOCK - lock = combine_locks([base_lock, get_write_lock(filename)]) + lock = NETCDFC_LOCK + + if isinstance(filename, str): + lock = combine_locks([lock, get_write_lock(filename)]) kwargs = dict( clobber=clobber, @@ -462,9 +496,31 @@ def open( ) if auto_complex is not None: kwargs["auto_complex"] = auto_complex - manager = CachingFileManager( - netCDF4.Dataset, filename, mode=mode, kwargs=kwargs - ) + + if isinstance(filename, BytesIOProxy): + assert mode == "w" + # Size hint used for creating netCDF3 files. Per the documentation + # for nc__create(), the special value NC_SIZEHINT_DEFAULT (which is + # the value 0), lets the netcdf library choose a suitable initial + # size. + memory = 0 + kwargs["diskless"] = False + nc4_dataset = netCDF4.Dataset( + "", mode=mode, memory=memory, **kwargs + ) + close = _CloseWithCopy(filename, nc4_dataset) + manager = DummyFileManager(nc4_dataset, close=close) + + elif isinstance(filename, bytes | memoryview): + assert mode == "r" + kwargs["memory"] = filename + manager = PickleableFileManager( + netCDF4.Dataset, "", mode=mode, kwargs=kwargs + ) + else: + manager = CachingFileManager( + netCDF4.Dataset, filename, mode=mode, kwargs=kwargs + ) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): @@ -646,7 +702,12 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True - magic_number = try_read_magic_number_from_path(filename_or_obj) + + magic_number = ( + bytes(filename_or_obj[:8]) + if isinstance(filename_or_obj, bytes | memoryview) + else try_read_magic_number_from_path(filename_or_obj) + ) if magic_number is not None: # netcdf 3 or HDF5 return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 40c610bbad0..db188d88b9c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -35,6 +35,12 @@ ) from xarray.core.variable import Variable +try: + from scipy.io import netcdf_file as netcdf_file_base +except ImportError: + netcdf_file_base = object + + if TYPE_CHECKING: import scipy.io @@ -104,13 +110,38 @@ def __setitem__(self, key, value): raise -def _open_scipy_netcdf(filename, mode, mmap, version): +# TODO: Make the scipy import lazy again after upstreaming these fixes. +class flush_only_netcdf_file(netcdf_file_base): + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # https://github.com/scipy/scipy/issues/13905 + + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. + # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. + def close(self): + if hasattr(self, "fp") and not self.fp.closed: + self.flush() + self.fp.seek(0) # allow file to be read again + + def __del__(self): + # Remove the __del__ method, which in scipy is aliased to close(). + # These files need to be closed explicitly by xarray. + pass + + +def _open_scipy_netcdf(filename, mode, mmap, version, flush_only=False): import scipy.io + netcdf_file = flush_only_netcdf_file if flush_only else scipy.io.netcdf_file + # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: - return scipy.io.netcdf_file( + return netcdf_file( gzip.open(filename), mode=mode, mmap=mmap, version=version ) except TypeError as e: @@ -124,7 +155,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise try: - return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) + return netcdf_file(filename, mode=mode, mmap=mmap, version=version) except TypeError as e: # netcdf3 message is obscure in this case errmsg = e.args[0] if "is not a valid NetCDF 3 file" in errmsg: @@ -184,19 +215,14 @@ def __init__( # Note: checking for .seek matches the check for file objects # in scipy.io.netcdf_file scipy_dataset = _open_scipy_netcdf( - filename_or_obj, mode=mode, mmap=mmap, version=version + filename_or_obj, + mode=mode, + mmap=mmap, + version=version, + flush_only=True, ) - # scipy.io.netcdf_file.close() incorrectly closes file objects that - # were passed in as constructor arguments: - # https://github.com/scipy/scipy/issues/13905 - # Instead of closing such files, only call flush(), which is - # equivalent as long as the netcdf_file object is not mmapped. - # This suffices to keep BytesIO objects open long enough to read - # their contents from to_netcdf(), but underlying files still get - # closed when the netcdf_file is garbage collected (via __del__), - # and will need to be fixed upstream in scipy. assert not scipy_dataset.use_mmap # no mmap for file objects - manager = DummyFileManager(scipy_dataset, close=scipy_dataset.flush) + manager = DummyFileManager(scipy_dataset) else: raise ValueError( f"cannot open {filename_or_obj=} with scipy.io.netcdf_file" diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index f425170c271..b158862aba0 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -55,10 +55,10 @@ def _datatree_to_netcdf( filepath = _normalize_path(filepath) if engine is None: - to_fileobject_or_memoryview = not isinstance(filepath, str) + to_fileobject = isinstance(filepath, io.IOBase) engine = get_default_netcdf_write_engine( format="NETCDF4", # required for supporting groups - to_fileobject_or_memoryview=to_fileobject_or_memoryview, + to_fileobject=to_fileobject, ) # type: ignore[assignment] if group is not None: diff --git a/xarray/core/types.py b/xarray/core/types.py index 2305ce56199..a0d62d30c9f 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -364,3 +364,14 @@ def read(self, n: int = ..., /) -> AnyStr_co: ] ResampleCompatible: TypeAlias = str | datetime.timedelta | pd.Timedelta | pd.DateOffset + + +class Closable(Protocol): + def close(self) -> None: ... + + +class Lock(Protocol): + def acquire(self, *args, **kwargs) -> Any: ... + def release(self) -> None: ... + def __enter__(self) -> Any: ... + def __exit__(self, *args, **kwargs) -> None: ... diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 3b4e49c64d8..ce0d39b6ad0 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -179,6 +179,10 @@ def _importorskip( requires_scipy_or_netCDF4 = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) +has_h5netcdf_or_netCDF4 = has_h5netcdf or has_netCDF4 +requires_h5netcdf_or_netCDF4 = pytest.mark.skipif( + not has_h5netcdf_or_netCDF4, reason="requires h5netcdf or netCDF4" +) has_numbagg_or_bottleneck = has_numbagg or has_bottleneck requires_numbagg_or_bottleneck = pytest.mark.skipif( not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index c5a53709d79..715314daddd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -89,6 +89,7 @@ requires_fsspec, requires_h5netcdf, requires_h5netcdf_1_4_0_or_above, + requires_h5netcdf_or_netCDF4, requires_h5netcdf_ros3, requires_iris, requires_netcdf, @@ -115,12 +116,12 @@ with contextlib.suppress(ImportError): import netCDF4 as nc4 -try: +with contextlib.suppress(ImportError): import dask import dask.array as da -except ImportError: - pass +with contextlib.suppress(ImportError): + import fsspec if has_zarr: import zarr @@ -631,16 +632,13 @@ def test_pickle(self) -> None: with pickle.loads(raw_pickle) as unpickled_ds: assert_identical(expected, unpickled_ds) - @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_pickle_dataarray(self) -> None: expected = Dataset({"foo": ("x", [42])}) with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: raw_pickle = pickle.dumps(roundtripped["foo"]) - # TODO: figure out how to explicitly close the file for the - # unpickled DataArray? - unpickled = pickle.loads(raw_pickle) - assert_identical(expected["foo"], unpickled) + with pickle.loads(raw_pickle) as unpickled: + assert_identical(expected["foo"], unpickled) def test_dataset_caching(self) -> None: expected = Dataset({"foo": ("x", [5, 6, 7])}) @@ -656,7 +654,6 @@ def test_dataset_caching(self) -> None: _ = actual.foo.values # no caching assert not actual.foo.variable._in_memory - @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_roundtrip_None_variable(self) -> None: expected = Dataset({None: (("x", "y"), [[0, 1], [2, 3]])}) with self.roundtrip(expected) as actual: @@ -2492,6 +2489,70 @@ def test_deepcopy(self) -> None: assert_identical(expected, copied) +class InMemoryNetCDF: + engine: T_NetcdfEngine | None + + def test_roundtrip_via_memoryview(self) -> None: + original = create_test_data() + result = original.to_netcdf(engine=self.engine) + roundtrip = load_dataset(result, engine=self.engine) + assert_identical(roundtrip, original) + + def test_roundtrip_via_bytes(self) -> None: + original = create_test_data() + result = bytes(original.to_netcdf(engine=self.engine)) + roundtrip = load_dataset(result, engine=self.engine) + assert_identical(roundtrip, original) + + def test_pickle_open_dataset_from_bytes(self) -> None: + original = Dataset({"foo": ("x", [1, 2, 3])}) + netcdf_bytes = bytes(original.to_netcdf(engine=self.engine)) + with open_dataset(netcdf_bytes, engine=self.engine) as roundtrip: + with pickle.loads(pickle.dumps(roundtrip)) as unpickled: + assert_identical(unpickled, original) + + def test_compute_false(self) -> None: + original = create_test_data() + with pytest.raises( + NotImplementedError, + match=re.escape("to_netcdf() with compute=False is not yet implemented"), + ): + original.to_netcdf(engine=self.engine, compute=False) + + +class InMemoryNetCDFWithGroups(InMemoryNetCDF): + def test_roundtrip_group_via_memoryview(self) -> None: + original = create_test_data() + netcdf_bytes = original.to_netcdf(group="sub", engine=self.engine) + roundtrip = load_dataset(netcdf_bytes, group="sub", engine=self.engine) + assert_identical(roundtrip, original) + + +class FileObjectNetCDF: + engine: T_NetcdfEngine + + def test_file_remains_open(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}) + f = BytesIO() + data.to_netcdf(f, engine=self.engine) + assert not f.closed + restored = open_dataset(f, engine=self.engine) + assert not f.closed + assert_identical(restored, data) + restored.close() + assert not f.closed + + +@requires_h5netcdf_or_netCDF4 +class TestGenericNetCDF4InMemory(InMemoryNetCDFWithGroups): + engine = None + + +@requires_netCDF4 +class TestNetCDF4InMemory(InMemoryNetCDFWithGroups): + engine: T_NetcdfEngine = "netcdf4" + + @requires_netCDF4 @requires_dask @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") @@ -4495,7 +4556,7 @@ def test_zarr_version_deprecated() -> None: @requires_scipy -class TestScipyInMemoryData(NetCDF3Only, CFEncodedBase): +class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only, InMemoryNetCDF): engine: T_NetcdfEngine = "scipy" @contextlib.contextmanager @@ -4503,38 +4564,26 @@ def create_store(self): fobj = BytesIO() yield backends.ScipyDataStore(fobj, "w") + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + saved = self.save(data, path=None, **save_kwargs) + with self.open(saved, **open_kwargs) as ds: + yield ds + @pytest.mark.asyncio @pytest.mark.skip(reason="NetCDF backends don't support async loading") async def test_load_async(self) -> None: await super().test_load_async() - def test_to_netcdf_explicit_engine(self) -> None: - Dataset({"foo": 42}).to_netcdf(engine="scipy") - - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(engine="scipy") - roundtrip = open_dataset(netcdf_bytes, engine="scipy") - assert_identical(roundtrip, original) - - def test_to_bytes_compute_false(self) -> None: - original = create_test_data() - with pytest.raises( - NotImplementedError, - match=re.escape("to_netcdf() with compute=False is not yet implemented"), - ): - original.to_netcdf(engine="scipy", compute=False) - - def test_bytes_pickle(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - fobj = data.to_netcdf(engine="scipy") - with self.open(fobj) as ds: - unpickled = pickle.loads(pickle.dumps(ds)) - assert_identical(unpickled, data) - @requires_scipy -class TestScipyFileObject(NetCDF3Only, CFEncodedBase): +class TestScipyFileObject(CFEncodedBase, NetCDF3Only, FileObjectNetCDF): # TODO: Consider consolidating some of these cases (e.g., # test_file_remains_open) with TestH5NetCDFFileObject engine: T_NetcdfEngine = "scipy" @@ -4559,27 +4608,18 @@ def roundtrip( with self.open(f, **open_kwargs) as ds: yield ds - @pytest.mark.xfail( - reason="scipy.io.netcdf_file closes files upon garbage collection" - ) - def test_file_remains_open(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - f = BytesIO() - data.to_netcdf(f, engine="scipy") - assert not f.closed - restored = open_dataset(f, engine="scipy") - assert not f.closed - assert_identical(restored, data) - restored.close() - assert not f.closed + @pytest.mark.asyncio + @pytest.mark.skip(reason="NetCDF backends don't support async loading") + async def test_load_async(self) -> None: + await super().test_load_async() @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle(self) -> None: - pass + super().test_pickle() @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle_dataarray(self) -> None: - pass + super().test_pickle_dataarray() @pytest.mark.parametrize("create_default_indexes", [True, False]) def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None: @@ -4692,19 +4732,23 @@ def test_engine(self) -> None: with pytest.raises(ValueError, match=r"unrecognized engine"): data.to_netcdf("foo.nc", engine="foobar") # type: ignore[call-overload] - with pytest.raises( - ValueError, - match=re.escape( - "can only read bytes or file-like objects with engine='scipy' or 'h5netcdf'" - ), - ): - data.to_netcdf(engine="netcdf4") - with create_tmp_file() as tmp_file: data.to_netcdf(tmp_file) with pytest.raises(ValueError, match=r"unrecognized engine"): open_dataset(tmp_file, engine="foobar") + with pytest.raises( + TypeError, + match=re.escape("file objects are not supported by the netCDF4 backend"), + ): + data.to_netcdf(BytesIO(), engine="netcdf4") + + with pytest.raises( + TypeError, + match=re.escape("file objects are not supported by the netCDF4 backend"), + ): + open_dataset(BytesIO(), engine="netcdf4") + bytes_io = BytesIO() data.to_netcdf(bytes_io, engine="scipy") with pytest.raises(ValueError, match=r"unrecognized engine"): @@ -5042,7 +5086,7 @@ def test_deepcopy(self) -> None: @requires_h5netcdf -class TestH5NetCDFFileObject(TestH5NetCDFData): +class TestH5NetCDFFileObject(TestH5NetCDFData, FileObjectNetCDF): engine: T_NetcdfEngine = "h5netcdf" def test_open_badbytes(self) -> None: @@ -5051,8 +5095,10 @@ def test_open_badbytes(self) -> None: ): with open_dataset(b"garbage"): pass - with pytest.raises(ValueError, match=r"can only read bytes"): - with open_dataset(b"garbage", engine="netcdf4"): + with pytest.raises( + ValueError, match=r"not the signature of a valid netCDF4 file" + ): + with open_dataset(b"garbage", engine="h5netcdf"): pass with pytest.raises( ValueError, match=r"not the signature of a valid netCDF4 file" @@ -5062,13 +5108,12 @@ def test_open_badbytes(self) -> None: def test_open_twice(self) -> None: expected = create_test_data() - expected.attrs["foo"] = "bar" with create_tmp_file() as tmp_file: - expected.to_netcdf(tmp_file, engine="h5netcdf") + expected.to_netcdf(tmp_file, engine=self.engine) with open(tmp_file, "rb") as f: - with open_dataset(f, engine="h5netcdf"): - with open_dataset(f, engine="h5netcdf"): - pass + with open_dataset(f, engine=self.engine): + with open_dataset(f, engine=self.engine): + pass # should not crash @requires_scipy def test_open_fileobj(self) -> None: @@ -5103,31 +5148,24 @@ def test_open_fileobj(self) -> None: with open_dataset(f): # ensure file gets closed pass - def test_file_remains_open(self) -> None: - data = Dataset({"foo": ("x", [1, 2, 3])}) - f = BytesIO() - data.to_netcdf(f, engine="h5netcdf") - assert not f.closed - restored = open_dataset(f, engine="h5netcdf") - assert not f.closed - assert_identical(restored, data) - restored.close() - assert not f.closed + @requires_fsspec + def test_fsspec(self) -> None: + expected = create_test_data() + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine="h5netcdf") + with fsspec.open(tmp_file, "rb") as f: + with open_dataset(f, engine="h5netcdf") as actual: + assert_identical(actual, expected) -@requires_h5netcdf -class TestH5NetCDFInMemoryData: - def test_roundtrip_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(engine="h5netcdf") - roundtrip = open_dataset(netcdf_bytes, engine="h5netcdf") - assert_identical(roundtrip, original) + # fsspec.open() creates a pickleable file, unlike open() + with pickle.loads(pickle.dumps(actual)) as unpickled: + assert_identical(unpickled, expected) - def test_roundtrip_group_via_bytes(self) -> None: - original = create_test_data() - netcdf_bytes = original.to_netcdf(group="sub", engine="h5netcdf") - roundtrip = open_dataset(netcdf_bytes, group="sub", engine="h5netcdf") - assert_identical(roundtrip, original) + +@requires_h5netcdf +class TestH5NetCDFInMemoryData(InMemoryNetCDFWithGroups): + engine: T_NetcdfEngine = "h5netcdf" @requires_h5netcdf @@ -5170,6 +5208,25 @@ def test_write_inconsistent_chunks(self) -> None: assert actual["y"].encoding["chunksizes"] == (100, 50) +@requires_netCDF4 +@requires_h5netcdf +def test_memoryview_write_h5netcdf_read_netcdf4() -> None: + original = create_test_data() + result = original.to_netcdf(engine="h5netcdf") + roundtrip = load_dataset(result, engine="netcdf4") + assert_identical(roundtrip, original) + + +@requires_netCDF4 +@requires_h5netcdf +def test_memoryview_write_netcdf4_read_h5netcdf() -> None: + original = create_test_data() + result = original.to_netcdf(engine="netcdf4") + roundtrip = load_dataset(result, engine="h5netcdf") + assert_identical(roundtrip, original) + + +@network @requires_h5netcdf_ros3 class TestH5NetCDFDataRos3Driver(TestCommon): engine: T_NetcdfEngine = "h5netcdf" @@ -7091,6 +7148,9 @@ def test_netcdf4_entrypoint(tmp_path: Path) -> None: assert entrypoint.guess_can_open("something-local.cdf") assert not entrypoint.guess_can_open("not-found-and-no-extension") + contents = ds.to_netcdf(engine="netcdf4") + _check_guess_can_open_and_open(entrypoint, contents, engine="netcdf4", expected=ds) + path = tmp_path / "baz" with open(path, "wb") as f: f.write(b"not-a-netcdf-file") @@ -7141,6 +7201,9 @@ def test_h5netcdf_entrypoint(tmp_path: Path) -> None: with open(path, "rb") as f: _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + contents = ds.to_netcdf(engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, contents, engine="h5netcdf", expected=ds) + assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc4") assert entrypoint.guess_can_open("something-local.cdf") diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index 2d659dcb9c9..b222c537ee4 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -23,23 +23,17 @@ @requires_scipy @requires_h5netcdf def test_get_default_netcdf_write_engine() -> None: - engine = get_default_netcdf_write_engine( - format=None, to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format=None, to_fileobject=False) assert engine == "netcdf4" - engine = get_default_netcdf_write_engine( - format="NETCDF4", to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=False) assert engine == "netcdf4" - engine = get_default_netcdf_write_engine( - format="NETCDF4", to_fileobject_or_memoryview=True - ) + engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=True) assert engine == "h5netcdf" engine = get_default_netcdf_write_engine( - format="NETCDF3_CLASSIC", to_fileobject_or_memoryview=True + format="NETCDF3_CLASSIC", to_fileobject=True ) assert engine == "scipy" @@ -52,20 +46,17 @@ def test_default_engine_h5netcdf(monkeypatch): monkeypatch.delitem(sys.modules, "scipy", raising=False) monkeypatch.setattr(sys, "meta_path", []) - engine = get_default_netcdf_write_engine( - format=None, to_fileobject_or_memoryview=False - ) + engine = get_default_netcdf_write_engine(format=None, to_fileobject=False) assert engine == "h5netcdf" with pytest.raises( ValueError, match=re.escape( - "cannot write NetCDF files with format='NETCDF3_CLASSIC' because none of the suitable backend libraries (netCDF4, scipy) are installed" + "cannot write NetCDF files with format='NETCDF3_CLASSIC' because " + "none of the suitable backend libraries (netCDF4, scipy) are installed" ), ): - get_default_netcdf_write_engine( - format="NETCDF3_CLASSIC", to_fileobject_or_memoryview=False - ) + get_default_netcdf_write_engine(format="NETCDF3_CLASSIC", to_fileobject=False) def test_custom_engine() -> None: diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index de6c7499c9f..8d03781592d 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -11,8 +11,7 @@ import pytest import xarray as xr -from xarray.backends.api import open_datatree, open_groups -from xarray.core.datatree import DataTree +from xarray import DataTree, load_datatree, open_datatree, open_groups from xarray.testing import assert_equal, assert_identical from xarray.tests import ( has_zarr_v3, @@ -20,6 +19,7 @@ parametrize_zarr_format, requires_dask, requires_h5netcdf, + requires_h5netcdf_or_netCDF4, requires_netCDF4, requires_pydap, requires_zarr, @@ -33,6 +33,9 @@ import netCDF4 as nc4 +ON_WINDOWS = sys.platform == "win32" + + class TestNetCDF4DataTree(_TestNetCDF4Data): @contextlib.contextmanager def open(self, path, **kwargs): @@ -303,9 +306,12 @@ def test_compute_false(self, tmpdir, simple_datatree): original_dt = simple_datatree.chunk() result = original_dt.to_netcdf(filepath, engine=self.engine, compute=False) - with open_datatree(filepath, engine=self.engine) as in_progress_dt: - assert in_progress_dt.isomorphic(original_dt) - assert not in_progress_dt.equals(original_dt) + if not ON_WINDOWS: + # File at filepath is not closed until .compute() is called. On + # Windows, this means we can't open it yet. + with open_datatree(filepath, engine=self.engine) as in_progress_dt: + assert in_progress_dt.isomorphic(original_dt) + assert not in_progress_dt.equals(original_dt) result.compute() with open_datatree(filepath, engine=self.engine) as written_dt: @@ -322,26 +328,11 @@ def test_default_write_engine(self, tmpdir, simple_datatree, monkeypatch): original_dt.to_netcdf(filepath) # should not raise -@requires_netCDF4 -class TestNetCDF4DatatreeIO(DatatreeIOBase): - engine: T_DataTreeNetcdfEngine | None = "netcdf4" - - def test_open_datatree(self, unaligned_datatree_nc) -> None: - """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - - with pytest.raises( - ValueError, - match=( - re.escape( - "group '/Group1/subgroup1' is not aligned with its parents:\nGroup:\n" - ) - + ".*" - ), - ): - open_datatree(unaligned_datatree_nc) +class NetCDFIOBase(DatatreeIOBase): + engine: T_DataTreeNetcdfEngine | None @requires_dask - def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: + def test_open_datatree_chunks(self, tmpdir) -> None: filepath = tmpdir / "test.nc" chunks = {"x": 2, "y": 1} @@ -356,13 +347,72 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: "/group2": set2_data.chunk(chunks), } ) - original_tree.to_netcdf(filepath, engine="netcdf4") + original_tree.to_netcdf(filepath, engine=self.engine) - with open_datatree(filepath, engine="netcdf4", chunks=chunks) as tree: + with open_datatree(filepath, engine=self.engine, chunks=chunks) as tree: xr.testing.assert_identical(tree, original_tree) assert_chunks_equal(tree, original_tree, enforce_dask=True) + # def test_roundtrip_via_memoryview(self, simple_datatree) -> None: + # original_dt = simple_datatree + # memview = original_dt.to_netcdf(engine=self.engine) + # roundtrip_dt = load_datatree(memview, engine=self.engine) + # assert_equal(original_dt, roundtrip_dt) + + def test_to_bytes_compute_false(self, simple_datatree) -> None: + original_dt = simple_datatree + with pytest.raises( + NotImplementedError, + match=re.escape("to_netcdf() with compute=False is not yet implemented"), + ): + original_dt.to_netcdf(engine=self.engine, compute=False) + + def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: + """Test opening a specific group within a NetCDF file using `open_datatree`.""" + filepath = tmpdir / "test.nc" + group = "/set1" + original_dt = simple_datatree + original_dt.to_netcdf(filepath, engine=self.engine) + expected_subtree = original_dt[group].copy() + expected_subtree.orphan() + with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree: + assert subgroup_tree.root.parent is None + assert_equal(subgroup_tree, expected_subtree) + + +@requires_h5netcdf_or_netCDF4 +class TestGenericNetCDFIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = None + + @requires_h5netcdf + @requires_netCDF4 + def test_memoryview_write_h5netcdf_read_netcdf4(self, simple_datatree) -> None: + original_dt = simple_datatree + memview = original_dt.to_netcdf(engine="h5netcdf") + roundtrip_dt = load_datatree(memview, engine="netcdf4") + assert_equal(original_dt, roundtrip_dt) + + @requires_h5netcdf + @requires_netCDF4 + def test_memoryview_write_netcdf4_read_h5netcdf(self, simple_datatree) -> None: + original_dt = simple_datatree + memview = original_dt.to_netcdf(engine="netcdf4") + roundtrip_dt = load_datatree(memview, engine="h5netcdf") + assert_equal(original_dt, roundtrip_dt) + + def test_open_datatree_unaligned_hierarchy(self, unaligned_datatree_nc) -> None: + with pytest.raises( + ValueError, + match=( + re.escape( + "group '/Group1/subgroup1' is not aligned with its parents:\nGroup:\n" + ) + + ".*" + ), + ): + open_datatree(unaligned_datatree_nc) + def test_open_groups(self, unaligned_datatree_nc) -> None: """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy.""" unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc) @@ -405,7 +455,7 @@ def test_open_groups_chunks(self, tmpdir) -> None: ) original_tree.to_netcdf(filepath, mode="w") - dict_of_datasets = open_groups(filepath, engine="netcdf4", chunks=chunks) + dict_of_datasets = open_groups(filepath, chunks=chunks) for path, ds in dict_of_datasets.items(): assert {k: max(vs) for k, vs in ds.chunksizes.items()} == chunks, ( @@ -415,6 +465,11 @@ def test_open_groups_chunks(self, tmpdir) -> None: for ds in dict_of_datasets.values(): ds.close() + +@requires_netCDF4 +class TestNetCDF4DatatreeIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = "netcdf4" + def test_open_groups_to_dict(self, tmpdir) -> None: """Create an aligned netCDF4 with the following structure to test `open_groups` and `DataTree.from_dict`. @@ -462,17 +517,45 @@ def test_open_groups_to_dict(self, tmpdir) -> None: for ds in aligned_dict_of_datasets.values(): ds.close() - def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: - """Test opening a specific group within a NetCDF file using `open_datatree`.""" - filepath = tmpdir / "test.nc" - group = "/set1" + +@requires_h5netcdf +class TestH5NetCDFDatatreeIO(NetCDFIOBase): + engine: T_DataTreeNetcdfEngine | None = "h5netcdf" + + def test_phony_dims_warning(self, tmpdir) -> None: + filepath = tmpdir + "/phony_dims.nc" + import h5py + + foo_data = np.arange(125).reshape(5, 5, 5) + bar_data = np.arange(625).reshape(25, 5, 5) + var = {"foo1": foo_data, "foo2": bar_data, "foo3": foo_data, "foo4": bar_data} + with h5py.File(filepath, "w") as f: + grps = ["bar", "baz"] + for grp in grps: + fx = f.create_group(grp) + for k, v in var.items(): + fx.create_dataset(k, data=v) + + with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"): + with open_datatree(filepath, engine=self.engine) as tree: + assert tree.bar.dims == { + "phony_dim_0": 5, + "phony_dim_1": 5, + "phony_dim_2": 5, + "phony_dim_3": 25, + } + + def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree) -> None: original_dt = simple_datatree - original_dt.to_netcdf(filepath) - expected_subtree = original_dt[group].copy() - expected_subtree.orphan() - with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree: - assert subgroup_tree.root.parent is None - assert_equal(subgroup_tree, expected_subtree) + filepath = tmpdir + "/test.nc" + # h5py requires both read and write access when writing, it will + # work with file-like objects provided they support both, and are + # seekable. + with open(filepath, "wb+") as file: + original_dt.to_netcdf(file, engine=self.engine) + with open(filepath, "rb") as file: + with open_datatree(file, engine=self.engine) as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) @network @@ -490,9 +573,9 @@ class TestPyDAPDatatreeIO: ) simplegroup_datatree_url = "dap4://test.opendap.org/opendap/dap4/SimpleGroup.nc4.h5" - def test_open_datatree(self, url=unaligned_datatree_url) -> None: - """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - + def test_open_datatree_unaligned_hierarchy( + self, url=unaligned_datatree_url + ) -> None: with pytest.raises( ValueError, match=( @@ -568,64 +651,6 @@ def test_open_groups_to_dict(self, url=all_aligned_child_nodes_url) -> None: assert opened_tree.identical(aligned_dt) -@requires_h5netcdf -class TestH5NetCDFDatatreeIO(DatatreeIOBase): - engine: T_DataTreeNetcdfEngine | None = "h5netcdf" - - def test_phony_dims_warning(self, tmpdir) -> None: - filepath = tmpdir + "/phony_dims.nc" - import h5py - - foo_data = np.arange(125).reshape(5, 5, 5) - bar_data = np.arange(625).reshape(25, 5, 5) - var = {"foo1": foo_data, "foo2": bar_data, "foo3": foo_data, "foo4": bar_data} - with h5py.File(filepath, "w") as f: - grps = ["bar", "baz"] - for grp in grps: - fx = f.create_group(grp) - for k, v in var.items(): - fx.create_dataset(k, data=v) - - with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"): - with open_datatree(filepath, engine=self.engine) as tree: - assert tree.bar.dims == { - "phony_dim_0": 5, - "phony_dim_1": 5, - "phony_dim_2": 5, - "phony_dim_3": 25, - } - - def test_roundtrip_via_bytes(self, simple_datatree) -> None: - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf()) - assert_equal(original_dt, roundtrip_dt) - - def test_roundtrip_via_bytes_engine_specified(self, simple_datatree) -> None: - original_dt = simple_datatree - roundtrip_dt = open_datatree(original_dt.to_netcdf(engine=self.engine)) - assert_equal(original_dt, roundtrip_dt) - - def test_to_bytes_compute_false(self, simple_datatree) -> None: - original_dt = simple_datatree - with pytest.raises( - NotImplementedError, - match=re.escape("to_netcdf() with compute=False is not yet implemented"), - ): - original_dt.to_netcdf(compute=False) - - def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree) -> None: - original_dt = simple_datatree - filepath = tmpdir + "/test.nc" - # h5py requires both read and write access when writing, it will - # work with file-like objects provided they support both, and are - # seekable. - with open(filepath, "wb+") as file: - original_dt.to_netcdf(file, engine=self.engine) - with open(filepath, "rb") as file: - with open_datatree(file, engine=self.engine) as roundtrip_dt: - assert_equal(original_dt, roundtrip_dt) - - @requires_zarr @parametrize_zarr_format class TestZarrDatatreeIO: @@ -796,13 +821,12 @@ def assert_expected_zarr_files_exist( zarr_format=zarr_format, ) - with open_datatree(str(storepath), engine="zarr") as in_progress_dt: - assert in_progress_dt.isomorphic(original_dt) - assert not in_progress_dt.equals(original_dt) + in_progress_dt = load_datatree(str(storepath), engine="zarr") + assert not in_progress_dt.equals(original_dt) result.compute() - with open_datatree(str(storepath), engine="zarr") as written_dt: - assert_identical(written_dt, original_dt) + written_dt = load_datatree(str(storepath), engine="zarr") + assert_identical(written_dt, original_dt) @requires_dask def test_rplus_mode( @@ -873,8 +897,9 @@ def test_open_groups_round_trip(self, tmpdir, simple_datatree, zarr_format) -> N @pytest.mark.filterwarnings( "ignore:Failed to open Zarr store with consolidated metadata:RuntimeWarning" ) - def test_open_datatree(self, unaligned_datatree_zarr_factory, zarr_format) -> None: - """Test if `open_datatree` fails to open a zarr store with an unaligned group hierarchy.""" + def test_open_datatree_unaligned_hierarchy( + self, unaligned_datatree_zarr_factory, zarr_format + ) -> None: storepath = unaligned_datatree_zarr_factory(zarr_format=zarr_format) with pytest.raises( diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index ab1ac4a06d9..7b5fbee6309 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -7,7 +7,7 @@ import pytest -from xarray.backends.file_manager import CachingFileManager +from xarray.backends.file_manager import CachingFileManager, PickleableFileManager from xarray.backends.lru_cache import LRUCache from xarray.core.options import set_options from xarray.tests import assert_no_warnings @@ -262,3 +262,30 @@ class AcquisitionError(Exception): assert file_cache # file *was* already open manager.close() + + +def test_pickleable_file_manager_write_pickle(tmpdir) -> None: + path = str(tmpdir.join("testing.txt")) + manager = PickleableFileManager(open, path, mode="w") + f = manager.acquire() + f.write("foo") + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write("bar") + manager2.close() + manager.close() + + with open(path) as f: + assert f.read() == "foobar" + + +def test_pickleable_file_manager_preserves_closed(tmpdir) -> None: + path = str(tmpdir.join("testing.txt")) + manager = PickleableFileManager(open, path, mode="w") + f = manager.acquire() + f.write("foo") + manager.close() + manager2 = pickle.loads(pickle.dumps(manager)) + assert manager2._closed + assert repr(manager2) == "" diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index f1342a1f82a..c23a5487bd6 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -227,7 +227,7 @@ def test_lazy_import() -> None: "numbagg", "pint", "pydap", - "scipy", + # "scipy", # TODO: xarray.backends.scipy_ is currently not lazy "sparse", "zarr", ]