diff --git a/pyproject.toml b/pyproject.toml index 4ae8ee384..487f02e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,7 @@ dependencies = [ "packaging>=24.2", "array_api_compat>=1.7.1", "legacy-api-wrap", - # <3.1 on account of https://github.com/scverse/anndata/pull/1995 - "zarr >=2.18.7, !=3.0.0, !=3.0.1, !=3.0.2, !=3.0.3, !=3.0.4, !=3.0.5, !=3.0.6, !=3.0.7, <3.1", + "zarr >=2.18.7, !=3.0.*", ] dynamic = [ "version" ] diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index aaf18524f..9b93258d5 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -245,7 +245,7 @@ def to_memory(self, *, copy: bool = False) -> pd.DataFrame: if df.index.name != index_key and index_key is not None: df = df.set_index(index_key) for col in set(self.columns) - non_nullable_string_cols: - df[col] = pd.array(self[col].data, dtype="string") + df[col] = df[col].astype(dtype="string") df.index.name = None # matches old AnnData object return df diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 78700470b..97d1a8640 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -24,6 +24,7 @@ from anndata._io.utils import H5PY_V3, check_key, zero_dim_array_as_scalar from anndata._warnings import OldFormatWarning from anndata.compat import ( + NULLABLE_NUMPY_STRING_TYPE, AwkArray, CupyArray, CupyCSCMatrix, @@ -431,7 +432,7 @@ def write_basic( dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs) f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs) # see https://github.com/zarr-developers/zarr-python/discussions/2712 - if isinstance(elem, ZarrArray): + if isinstance(elem, ZarrArray | H5Array): f[k][...] = elem[...] else: f[k][...] = elem @@ -622,24 +623,20 @@ def write_vlen_string_array_zarr( f[k][:] = elem else: from numcodecs import VLenUTF8 + from zarr.core.dtype import VariableLengthUTF8 dataset_kwargs = dataset_kwargs.copy() dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs) - match ( - ad.settings.zarr_write_format, - Version(np.__version__) >= Version("2.0.0"), - ): - case 2, _: - filters, dtype = [VLenUTF8()], object - case 3, True: - filters, dtype = None, np.dtypes.StringDType() - case 3, False: - filters, dtype = None, np.dtypes.ObjectDType() + dtype = VariableLengthUTF8() + filters, fill_value = None, None + if ad.settings.zarr_write_format == 2: + filters, fill_value = [VLenUTF8()], "" f.create_array( k, shape=elem.shape, dtype=dtype, filters=filters, + fill_value=fill_value, **dataset_kwargs, ) f[k][:] = elem @@ -1210,7 +1207,10 @@ def _string_array( values: np.ndarray, mask: np.ndarray ) -> pd.api.extensions.ExtensionArray: """Construct a string array from values and mask.""" - arr = pd.array(values, dtype=pd.StringDtype()) + arr = pd.array( + values.astype(NULLABLE_NUMPY_STRING_TYPE), + dtype=pd.StringDtype(), + ) arr[mask] = pd.NA return arr @@ -1281,19 +1281,21 @@ def write_scalar_zarr( return f.create_dataset(key, data=np.array(value), shape=(), **dataset_kwargs) else: from numcodecs import VLenUTF8 + from zarr.core.dtype import VariableLengthUTF8 match ad.settings.zarr_write_format, value: case 2, str(): - filters, dtype = [VLenUTF8()], object + filters, dtype, fill_value = [VLenUTF8()], VariableLengthUTF8(), "" case 3, str(): - filters, dtype = None, np.dtypes.StringDType() + filters, dtype, fill_value = None, VariableLengthUTF8(), None case _, _: - filters, dtype = None, np.array(value).dtype + filters, dtype, fill_value = None, np.array(value).dtype, None a = f.create_array( key, shape=(), dtype=dtype, filters=filters, + fill_value=fill_value, **dataset_kwargs, ) a[...] = np.array(value) diff --git a/src/anndata/_io/zarr.py b/src/anndata/_io/zarr.py index 01a93829a..3b9667300 100644 --- a/src/anndata/_io/zarr.py +++ b/src/anndata/_io/zarr.py @@ -27,19 +27,6 @@ T = TypeVar("T") -def _check_rec_array(adata: AnnData) -> None: - if settings.zarr_write_format == 3 and ( - structured_dtype_keys := { - k - for k, v in adata.uns.items() - if isinstance(v, np.recarray) - or (isinstance(v, np.ndarray) and v.dtype.fields) - } - ): - msg = f"zarr v3 does not support structured dtypes. Found keys {structured_dtype_keys}" - raise NotImplementedError(msg) - - @no_write_dataset_2d def write_zarr( store: StoreLike, @@ -50,7 +37,6 @@ def write_zarr( **ds_kwargs, ) -> None: """See :meth:`~anndata.AnnData.write_zarr`.""" - _check_rec_array(adata) if isinstance(store, Path): store = str(store) if convert_strings_to_categoricals: diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index b8c63e4f2..00e81a80d 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -190,6 +190,13 @@ def old_positionals(*old_positionals): ############################# +NULLABLE_NUMPY_STRING_TYPE = ( + np.dtype("O") + if Version(np.__version__) < Version("2") + else np.dtypes.StringDType(na_object=pd.NA) +) + + @singledispatch def _read_attr(attrs: Mapping, name: str, default: Any | None = Empty): if default is Empty: @@ -404,10 +411,3 @@ def _map_cat_to_str(cat: pd.Categorical) -> pd.Categorical: return cat.map(str, na_action="ignore") else: return cat.map(str) - - -NULLABLE_NUMPY_STRING_TYPE = ( - np.dtype("O") - if Version(np.__version__) < Version("2") - else np.dtypes.StringDType(na_object=pd.NA) -) diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 7aafa17bd..3158fafc5 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -8,14 +8,15 @@ from anndata._core.index import _subset from anndata._core.views import as_view from anndata._io.specs.lazy_methods import get_chunksize -from anndata.compat import H5Array, ZarrArray from ..._settings import settings from ...compat import ( NULLABLE_NUMPY_STRING_TYPE, + H5Array, XBackendArray, XDataArray, XZarrArrayWrapper, + ZarrArray, ) from ...compat import xarray as xr diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 23b5e1503..6752b0708 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -17,7 +17,6 @@ from pandas.api.types import is_numeric_dtype from scipy import sparse -import anndata from anndata import AnnData, ExperimentalFeatureWarning, Raw from anndata._core.aligned_mapping import AlignedMappingBase from anndata._core.sparse_dataset import BaseCompressedSparseDataset @@ -413,10 +412,6 @@ def gen_adata( # noqa: PLR0913 awkward_ragged=gen_awkward((12, None, None)), # U_recarray=gen_vstr_recarray(N, 5, "U4") ) - # https://github.com/zarr-developers/zarr-python/issues/2134 - # zarr v3 on-disk does not write structured dtypes - if anndata.settings.zarr_write_format == 3: - del uns["O_recarray"] with warnings.catch_warnings(): warnings.simplefilter("ignore", ExperimentalFeatureWarning) adata = AnnData( @@ -1153,6 +1148,9 @@ def __getitem__(self, key: str) -> bytes: else: class AccessTrackingStore(AccessTrackingStoreBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, read_only=True) + async def get( self, key: str, diff --git a/tests/lazy/conftest.py b/tests/lazy/conftest.py index 6e181c70b..e6c9bab4f 100644 --- a/tests/lazy/conftest.py +++ b/tests/lazy/conftest.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +import zarr from scipy import sparse import anndata as ad @@ -126,7 +127,7 @@ def adata_remote_with_store_tall_skinny_path( worker_id: str = "serial", ) -> Path: orig_path = tmp_path_factory.mktemp(f"orig_{worker_id}.zarr") - M = 100_000 # forces zarr to chunk `obs` columns multiple ways - that way 1 access to `int64` below is actually only one access + M = 1000 N = 5 obs_names = pd.Index(f"cell{i}" for i in range(M)) var_names = pd.Index(f"gene{i}" for i in range(N)) @@ -139,6 +140,14 @@ def adata_remote_with_store_tall_skinny_path( ) orig.raw = orig.copy() orig.write_zarr(orig_path) + g = zarr.open_group(orig_path, mode="a", use_consolidated=False) + ad.io.write_elem( + g, + "obs", + obs, + dataset_kwargs=dict(chunks=(250,)), + ) + zarr.consolidate_metadata(g.store) return orig_path diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py index acb4ee575..90909e689 100644 --- a/tests/lazy/test_read.py +++ b/tests/lazy/test_read.py @@ -67,8 +67,8 @@ def test_access_count_subset( ["obs/cat/codes", *non_obs_elem_names] ) adata_remote_tall_skinny[adata_remote_tall_skinny.obs["cat"] == "a", :] - # all codes read in for subset (from 1 chunk) - remote_store_tall_skinny.assert_access_count("obs/cat/codes", 1) + # all codes read in for subset (from 4 chunks as set in the fixture) + remote_store_tall_skinny.assert_access_count("obs/cat/codes", 4) for elem_name in non_obs_elem_names: remote_store_tall_skinny.assert_access_count(elem_name, 0) diff --git a/tests/test_backed_sparse.py b/tests/test_backed_sparse.py index e5899bb37..3a72d0d3f 100644 --- a/tests/test_backed_sparse.py +++ b/tests/test_backed_sparse.py @@ -388,7 +388,7 @@ def test_lazy_array_cache( store = AccessTrackingStore(path) for elem in elems: store.initialize_key_trackers([f"X/{elem}"]) - f = open_write_group(store, mode="r") + f = zarr.open_group(store, mode="r") a_disk = sparse_dataset(f["X"], should_cache_indptr=should_cache_indptr) a_disk[:1] a_disk[3:5] diff --git a/tests/test_structured_arrays.py b/tests/test_structured_arrays.py index a22fa526b..3787d38c8 100644 --- a/tests/test_structured_arrays.py +++ b/tests/test_structured_arrays.py @@ -1,11 +1,9 @@ from __future__ import annotations -from contextlib import nullcontext from itertools import combinations, product from typing import TYPE_CHECKING import numpy as np -import pytest import anndata as ad from anndata import AnnData @@ -45,24 +43,17 @@ def test_io( initial = AnnData(np.zeros((3, 3))) initial.uns = dict(str_rec=str_recarray, u_rec=u_recarray, s_rec=s_recarray) - with ( - pytest.raises( - NotImplementedError, match=r"zarr v3 does not support structured dtypes" - ) - if diskfmt == "zarr" and ad.settings.zarr_write_format == 3 - else nullcontext() - ): - write1(initial, filepth1) - disk_once = read1(filepth1) - write2(disk_once, filepth2) - disk_twice = read2(filepth2) + write1(initial, filepth1) + disk_once = read1(filepth1) + write2(disk_once, filepth2) + disk_twice = read2(filepth2) - adatas = [initial, disk_once, disk_twice] - keys = [ - "str_rec", - "u_rec", - # "s_rec" - ] + adatas = [initial, disk_once, disk_twice] + keys = [ + "str_rec", + "u_rec", + # "s_rec" + ] - for (ad1, key1), (ad2, key2) in combinations(product(adatas, keys), 2): - assert_str_contents_equal(ad1.uns[key1], ad2.uns[key2]) + for (ad1, key1), (ad2, key2) in combinations(product(adatas, keys), 2): + assert_str_contents_equal(ad1.uns[key1], ad2.uns[key2])