diff --git a/src/spatialdata_io/_utils.py b/src/spatialdata_io/_utils.py index c73a4d69..e60010a6 100644 --- a/src/spatialdata_io/_utils.py +++ b/src/spatialdata_io/_utils.py @@ -2,17 +2,28 @@ import functools import warnings +from contextlib import contextmanager from typing import TYPE_CHECKING, Any, TypeVar +import zarr + if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Generator + + from zarr import Array, Group + from zarr.core.common import ( + AccessModeLiteral, + ) + from zarr.storage import StoreLike RT = TypeVar("RT") # these two functions should be removed and imported from spatialdata._utils once the multi_table branch, which # introduces them, is merged -def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]: +def deprecation_alias( + **aliases: str, +) -> Callable[[Callable[..., RT]], Callable[..., RT]]: """Decorate a function to warn user of use of arguments set for deprecation. Parameters @@ -52,7 +63,12 @@ def wrapper(*args: Any, **kwargs: Any) -> RT: return deprecation_decorator -def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str], class_name: None | str) -> None: +def rename_kwargs( + func_name: str, + kwargs: dict[str, Any], + aliases: dict[str, str], + class_name: None | str, +) -> None: """Rename function arguments set for deprecation and gives warning in case of usage of these arguments.""" for alias, new in aliases.items(): if alias in kwargs: @@ -71,3 +87,17 @@ def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str stacklevel=3, ) kwargs[new] = kwargs.pop(alias) + + +# workaround until https://github.com/zarr-developers/zarr-python/issues/2619 is closed +@contextmanager +def zarr_open( + store: StoreLike | None = None, + *, + mode: AccessModeLiteral | None = None, +) -> Generator[Array | Group, Any, None]: + f = zarr.open(store=store, mode=mode) + try: + yield f + finally: + f.store.close() diff --git a/src/spatialdata_io/readers/xenium.py b/src/spatialdata_io/readers/xenium.py index 31585c35..048660ce 100644 --- a/src/spatialdata_io/readers/xenium.py +++ b/src/spatialdata_io/readers/xenium.py @@ -38,7 +38,7 @@ from spatialdata_io._constants._constants import XeniumKeys from spatialdata_io._docs import inject_docs -from spatialdata_io._utils import deprecation_alias +from spatialdata_io._utils import deprecation_alias, zarr_open from spatialdata_io.readers._utils._read_10x_h5 import _read_10x_h5 from spatialdata_io.readers._utils._utils import _initialize_raster_models_kwargs @@ -417,7 +417,7 @@ def _get_labels_and_indices_mapping( with zipfile.ZipFile(zip_file, "r") as zip_ref: zip_ref.extractall(tmpdir) - with zarr.open(str(tmpdir), mode="r") as z: + with zarr_open(str(tmpdir), mode="r") as z: # get the labels masks = z["masks"][f"{mask_index}"][...] labels = Labels2DModel.parse( @@ -492,7 +492,7 @@ def _get_cells_metadata_table_from_zarr( with zipfile.ZipFile(zip_file, "r") as zip_ref: zip_ref.extractall(tmpdir) - with zarr.open(str(tmpdir), mode="r") as z: + with zarr_open(str(tmpdir), mode="r") as z: x = z["cell_summary"][...] column_names = z["cell_summary"].attrs["column_names"] df = pd.DataFrame(x, columns=column_names)