diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 9e9493de4..26b790b74 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -20,6 +20,9 @@ ZarrArray, ) +# Number of categories to show at head/tail in LazyCategoricalDtype repr +_N_CATEGORIES_REPR_SHOW = 10 + if TYPE_CHECKING: from pathlib import Path from typing import Literal @@ -27,7 +30,7 @@ from pandas._libs.missing import NAType from pandas.core.dtypes.base import ExtensionDtype - from anndata.compat import ZarrGroup + from anndata.compat import H5Group, ZarrGroup from ...compat import Index1DNorm @@ -36,6 +39,183 @@ from xarray.core.indexing import ExplicitIndexer +class LazyCategoricalDtype(pd.CategoricalDtype): + """A CategoricalDtype that lazily loads categories from zarr/h5 storage. + + This dtype provides efficient access to categorical metadata without loading + all categories into memory via :meth:`head_categories`, :meth:`tail_categories`, + and :attr:`n_categories`. Accessing :attr:`categories` will load all categories + into memory. + + Parameters + ---------- + categories_elem + The underlying zarr or h5 array (or group for nullable-string-array + encoding) containing category values. + ordered + Whether the categorical is ordered. + """ + + # Attributes that should be preserved during copying/pickling + _metadata = ("_categories_elem", "_ordered_flag") + + def __new__( + cls, + categories_elem: ZarrArray | H5Array | ZarrGroup | H5Group, + *, + ordered: bool = False, + ): + # Create instance without calling parent __init__ with categories + instance = object.__new__(cls) + return instance + + def __init__( + self, + categories_elem: ZarrArray | H5Array | ZarrGroup | H5Group, + *, + ordered: bool = False, + ): + self._categories_elem = categories_elem + self._ordered_flag = bool(ordered) + + def _get_categories_array(self) -> ZarrArray | H5Array: + """Get the underlying categories array. + + For string-array encoding: _categories_elem is directly the array. + For nullable-string-array encoding: _categories_elem would be a Group + with "values" key (not currently used for categories in anndata, but + handled defensively). + """ + if isinstance(self._categories_elem, (ZarrArray, H5Array)): + return self._categories_elem + # nullable-string-array encoding: Group with "values" and "mask" + return self._categories_elem["values"] + + @cached_property + def categories(self) -> pd.Index: + """Categories index. Loads all categories on first access and caches.""" + from anndata.io import read_elem + + return pd.Index(read_elem(self._categories_elem)) + + @property + def ordered(self) -> bool: + """Whether the categorical is ordered.""" + return self._ordered_flag + + @property + def n_categories(self) -> int: + """Number of categories (cheap, metadata only).""" + if "categories" in self.__dict__: + return len(self.categories) + return self._get_categories_array().shape[0] + + def _get_categories_slice( + self, n: int, *, from_end: bool = False + ) -> np.ndarray | pd.api.extensions.ExtensionArray: + """Get n categories from start or end. + + Parameters + ---------- + n + Number of categories to return. + from_end + If True, return last n categories. If False, return first n. + + Returns + ------- + np.ndarray or ExtensionArray + The requested categories. + """ + # If already fully loaded, slice from cache + if "categories" in self.__dict__: + sliced = self.categories[-n:] if from_end else self.categories[:n] + return np.asarray(sliced) + + # Read partial from disk + from anndata._io.specs.registry import read_elem_partial + + arr = self._get_categories_array() + total = arr.shape[0] + if from_end: + start, stop = max(total - n, 0), total + else: + start, stop = 0, min(n, total) + return read_elem_partial(arr, indices=slice(start, stop)) + + def head_categories( + self, n: int = 5 + ) -> np.ndarray | pd.api.extensions.ExtensionArray: + """Return first n categories without loading all into memory. + + Parameters + ---------- + n + Number of categories to return. Default 5. + + Returns + ------- + np.ndarray or ExtensionArray + The first n categories. + """ + return self._get_categories_slice(n, from_end=False) + + def tail_categories( + self, n: int = 5 + ) -> np.ndarray | pd.api.extensions.ExtensionArray: + """Return last n categories without loading all into memory. + + Parameters + ---------- + n + Number of categories to return. Default 5. + + Returns + ------- + np.ndarray or ExtensionArray + The last n categories. + """ + return self._get_categories_slice(n, from_end=True) + + def __repr__(self) -> str: + n_total = self.n_categories + ordered_str = ", ordered=True" if self.ordered else "" + + if n_total <= _N_CATEGORIES_REPR_SHOW * 2: + # Small enough to show all categories + if "categories" in self.__dict__: + cats = list(self.categories) + else: + cats = list(self.head_categories(n_total)) + return f"LazyCategoricalDtype(categories={cats!r}{ordered_str})" + + # Show truncated: first n ... last n + head = list(self.head_categories(_N_CATEGORIES_REPR_SHOW)) + tail = list(self.tail_categories(_N_CATEGORIES_REPR_SHOW)) + cats_display = [*head, "...", *tail] + return f"LazyCategoricalDtype(categories={cats_display!r}, n={n_total}{ordered_str})" + + def __hash__(self) -> int: + """Hash based on identity of underlying array and ordered flag. + + Required for use in sets and as dictionary keys (e.g., collecting + unique dtypes across AnnData objects). + """ + return hash((id(self._categories_elem), self._ordered_flag)) + + def __eq__(self, other) -> bool: + if isinstance(other, LazyCategoricalDtype): + has_same_ordering = self._ordered_flag == other._ordered_flag + are_arrays_equal = (self._categories_elem is other._categories_elem) or ( + self._get_categories_array() == other._get_categories_array() + ) + return has_same_ordering and are_arrays_equal + # Defer to pandas base implementation for all other comparisons + # This handles string comparison ("category"), CategoricalDtype comparisons, + # and all edge cases (None categories, ordered vs unordered, etc.) + return super().__eq__(other) + + class ZarrOrHDF5Wrapper[K: (H5Array | H5AsTypeView, ZarrArray)](XZarrArrayWrapper): def __init__(self, array: K) -> None: # AstypeView from h5py .astype() lacks chunks attribute @@ -87,7 +267,7 @@ class CategoricalArray[K: (H5Array, ZarrArray)](XBackendArray): """ _codes: ZarrOrHDF5Wrapper[K] - _categories: K + _categories_elem: K shape: tuple[int, ...] base_path_or_zarr_group: Path | ZarrGroup elem_name: str @@ -102,21 +282,22 @@ def __init__( ordered: bool, **kwargs, ): - self._categories = categories + self._categories_elem = categories self._ordered = ordered self._codes = ZarrOrHDF5Wrapper(codes) self.shape = self._codes.shape self.base_path_or_zarr_group = base_path_or_zarr_group self.file_format = "zarr" if isinstance(codes, ZarrArray) else "h5" self.elem_name = elem_name + # Create the lazy dtype - this is where categories are cached + self._lazy_dtype = LazyCategoricalDtype( + categories_elem=categories, ordered=ordered + ) - @cached_property - def categories(self) -> np.ndarray: - if isinstance(self._categories, ZarrArray): - return self._categories[...] - from anndata.io import read_elem - - return read_elem(self._categories) + @property + def categories(self) -> pd.Index | None: + """All categories. Loads and caches on first access.""" + return self._lazy_dtype.categories def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray: from xarray.core.extension_array import PandasExtensionArray @@ -129,9 +310,10 @@ def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray: categorical_array = categorical_array.remove_unused_categories() return PandasExtensionArray(categorical_array) - @cached_property - def dtype(self): - return pd.CategoricalDtype(categories=self.categories, ordered=self._ordered) + @property + def dtype(self) -> LazyCategoricalDtype: + """The dtype with lazy category loading support.""" + return self._lazy_dtype # circumvent https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py index 5b6c53650..23485aa68 100644 --- a/tests/lazy/test_read.py +++ b/tests/lazy/test_read.py @@ -2,7 +2,9 @@ from importlib.util import find_spec from typing import TYPE_CHECKING +from unittest.mock import patch +import h5py import numpy as np import pandas as pd import pytest @@ -111,16 +113,29 @@ def test_access_count_dtype( adata_remote_tall_skinny: AnnData, adata_remote_with_store_tall_skinny_path: Path, ) -> None: - remote_store_tall_skinny.initialize_key_trackers(["obs/cat/categories"]) remote_store_tall_skinny.assert_access_count("obs/cat/categories", 0) - # This should only cause categories to be read in once (and their mask if applicable) - adata_remote_tall_skinny.obs["cat"].dtype # noqa: B018 - remote_store_tall_skinny.assert_access_count("obs/cat/categories", 1) - adata_remote_tall_skinny.obs["cat"].dtype # noqa: B018 + # Accessing dtype alone should NOT load categories (lazy loading) adata_remote_tall_skinny.obs["cat"].dtype # noqa: B018 - remote_store_tall_skinny.assert_access_count("obs/cat/categories", 1) + remote_store_tall_skinny.assert_access_count("obs/cat/categories", 0) + + # n_categories should also be cheap (metadata only) + _ = adata_remote_tall_skinny.obs["cat"].dtype.n_categories + remote_store_tall_skinny.assert_access_count("obs/cat/categories", 0) + + # Accessing categories should trigger loading (once, then cached) + count_before = remote_store_tall_skinny.get_access_count("obs/cat/categories") + _ = adata_remote_tall_skinny.obs["cat"].dtype.categories + count_after = remote_store_tall_skinny.get_access_count("obs/cat/categories") + assert count_after > count_before, "categories access should trigger read" + + # Subsequent accesses should use cache (no additional reads) + _ = adata_remote_tall_skinny.obs["cat"].dtype.categories + _ = adata_remote_tall_skinny.obs["cat"].dtype.categories + assert ( + remote_store_tall_skinny.get_access_count("obs/cat/categories") == count_after + ), "cached categories should not trigger additional reads" def test_uns_uses_dask(adata_remote: AnnData): @@ -234,6 +249,350 @@ def test_chunks_df( assert arr.chunksize == expected_chunks +# Session-scoped categorical fixtures parametrized by (n_categories, ordered) +# Data is written once per session; stores are opened per-test with backend parametrization + +# Configuration: (name, n_categories, ordered, category_names) +_CAT_CONFIGS: list[tuple[str, int, bool, list[str] | None]] = [ + ("n3", 3, False, ["a", "b", "c"]), # basic tests, equality, hashing + ("n100", 100, False, None), # truncation, n_categories, head/tail + ("ordered", 3, True, ["low", "medium", "high"]), # ordered categories +] + + +@pytest.fixture(scope="session") +def cat_data_paths(tmp_path_factory) -> dict[tuple[str, str], Path]: + """Create all categorical test data once per session, return paths dict.""" + base = tmp_path_factory.mktemp("categorical_data") + paths: dict[tuple[str, str], Path] = {} + + for name, n_cat, ordered, cat_names in _CAT_CONFIGS: + categories = cat_names or [f"cat_{i:02d}" for i in range(n_cat)] + cat = pd.Categorical(categories, categories=categories, ordered=ordered) + + # Write zarr + zarr_path = base / f"{name}.zarr" + store = zarr.open(zarr_path, mode="w") + write_elem(store, "cat", cat) + paths[(name, "zarr")] = zarr_path + + # Write h5ad + h5_path = base / f"{name}.h5ad" + with h5py.File(h5_path, mode="w") as f: + write_elem(f, "cat", cat) + paths[(name, "h5ad")] = h5_path + + return paths + + +def _open_cat_store(path: Path, backend: str): + """Open categorical store for either backend.""" + if backend == "zarr": + return zarr.open(path, mode="r")["cat"] + return h5py.File(path, mode="r")["cat"] + + +def _make_cat_fixture(config_name: str): + """Factory to create categorical store fixtures with zarr/h5ad parametrization.""" + + @pytest.fixture(params=["zarr", "h5ad"]) + def _fixture(request, cat_data_paths): + path = cat_data_paths[(config_name, request.param)] + store = _open_cat_store(path, request.param) + yield store + if request.param == "h5ad": + store.file.close() + + return _fixture + + +cat_n3_store = _make_cat_fixture("n3") +cat_n100_store = _make_cat_fixture("n100") +cat_ordered_store = _make_cat_fixture("ordered") + + +def test_lazy_categorical_dtype_n_categories(cat_n100_store): + """Test n_categories is cheap (metadata only) and uses cache when loaded.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_n100_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + # Before loading: n_categories should work without loading categories + assert "categories" not in dtype.__dict__ + assert dtype.n_categories == 100 + assert "categories" not in dtype.__dict__ # Still not loaded - proves metadata-only + assert dtype.ordered is False + + # After loading: n_categories should use cache + _ = dtype.categories # Force load + assert "categories" in dtype.__dict__ + assert dtype.n_categories == 100 # Uses cache now + + # Verify cache is used by modifying cached value + dtype.__dict__["categories"] = pd.Index(["x", "y", "z"]) + assert dtype.n_categories == 3 # Returns cached length, not disk length + + +def test_lazy_categorical_dtype_head_tail_categories(cat_n100_store): + """Test head_categories and tail_categories perform partial reads without loading all.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_n100_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + # Verify categories not loaded initially + assert "categories" not in dtype.__dict__ + + # Test head_categories (first n) - should NOT load all categories + first5 = dtype.head_categories(5) + assert len(first5) == 5 + assert list(first5) == [f"cat_{i:02d}" for i in range(5)] + assert "categories" not in dtype.__dict__ # Still not fully loaded + + # Test head_categories default (first 5) + default_head = dtype.head_categories() + assert len(default_head) == 5 + assert list(default_head) == [f"cat_{i:02d}" for i in range(5)] + assert "categories" not in dtype.__dict__ # Still not fully loaded + + # Test tail_categories (last n) - should NOT load all categories + last3 = dtype.tail_categories(3) + assert len(last3) == 3 + assert list(last3) == [f"cat_{i:02d}" for i in range(97, 100)] + assert "categories" not in dtype.__dict__ # Still not fully loaded + + # Test tail_categories default (last 5) + default_tail = dtype.tail_categories() + assert len(default_tail) == 5 + assert list(default_tail) == [f"cat_{i:02d}" for i in range(95, 100)] + assert "categories" not in dtype.__dict__ # Still not fully loaded + + # Test requesting more than available + all_head = dtype.head_categories(200) + assert len(all_head) == 100 + assert list(all_head) == [f"cat_{i:02d}" for i in range(100)] + + all_tail = dtype.tail_categories(200) + assert len(all_tail) == 100 + assert list(all_tail) == [f"cat_{i:02d}" for i in range(100)] + + +def test_lazy_categorical_dtype_categories_caching(cat_n3_store): + """Test that categories are cached after full load.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_n3_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + # Before loading, categories should not be cached (uses @cached_property) + assert "categories" not in dtype.__dict__ + + # Load categories + cats = dtype.categories + assert cats is not None + assert list(cats) == ["a", "b", "c"] + + # After loading, should be cached in __dict__ (cached_property pattern) + assert "categories" in dtype.__dict__ + + # Verify head/tail_categories use cache by modifying the cached value + dtype.__dict__["categories"] = pd.Index(["x", "y", "z", "w", "v"]) + head = dtype.head_categories(3) + assert list(head) == ["x", "y", "z"] # Returns cached values, not disk values + tail = dtype.tail_categories(3) + assert list(tail) == ["z", "w", "v"] # Returns cached values, not disk values + + +def test_lazy_categorical_dtype_ordered(cat_ordered_store): + """Test LazyCategoricalDtype with ordered categories.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_ordered_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + assert dtype.ordered is True + assert dtype.n_categories == 3 + assert list(dtype.categories) == ["low", "medium", "high"] + + +def test_lazy_categorical_dtype_repr(cat_n100_store, cat_n3_store): + """Test LazyCategoricalDtype repr shows truncated categories.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + # Test large number of categories (truncated repr) + lazy_cat = read_elem_lazy(cat_n100_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + repr_str = repr(dtype) + assert "LazyCategoricalDtype" in repr_str + assert "n=100" in repr_str + assert "..." in repr_str # Truncation indicator + assert "cat_00" in repr_str # Head category + assert "cat_99" in repr_str # Tail category + + # Test small number of categories (full repr) + small_lazy_cat = read_elem_lazy(cat_n3_store) + small_dtype = small_lazy_cat.dtype + + small_repr = repr(small_dtype) + assert "LazyCategoricalDtype" in small_repr + assert "..." not in small_repr # No truncation for small categories + assert "'a'" in small_repr + assert "'b'" in small_repr + assert "'c'" in small_repr + + +def test_lazy_categorical_dtype_equality(cat_n3_store): + """Test LazyCategoricalDtype equality comparisons and basic properties.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_n3_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + # Test name property (inherited from CategoricalDtype) + assert dtype.name == "category" + + # Test string comparison (dtype == "category") + assert dtype == "category" + assert dtype != "int64" + + # Test comparison with regular CategoricalDtype + regular_dtype = pd.CategoricalDtype(categories=["a", "b", "c"], ordered=False) + assert dtype == regular_dtype + + # Test comparison with different categories + different_dtype = pd.CategoricalDtype(categories=["x", "y", "z"], ordered=False) + assert dtype != different_dtype + + # Test comparison with different ordered flag + ordered_dtype = pd.CategoricalDtype(categories=["a", "b", "c"], ordered=True) + assert dtype != ordered_dtype + + # Test comparison with CategoricalDtype with None categories + # LazyCategoricalDtype always has categories, so should not equal None-categories dtype + dtype_none = pd.CategoricalDtype(categories=None, ordered=False) + assert dtype != dtype_none + + # Test comparison with non-CategoricalDtype + assert dtype != np.dtype("int64") + assert dtype != 123 + assert dtype is not None + + +@pytest.mark.parametrize("backend", ["zarr", "h5ad"]) +def test_lazy_categorical_dtype_equality_no_load( + cat_data_paths: dict[tuple[str, str], Path], backend: str +): + """Test same-location equality doesn't load category data. + + LazyCategoricalDtype uses location-based comparison to avoid loading categories: + - zarr: StorePath comparison + - h5py: HDF5 object ID comparison + + We patch read_elem to verify no data is loaded during comparison. + """ + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + path = cat_data_paths[("n3", backend)] + + if backend == "zarr": + open_store = lambda p: zarr.open(p, mode="r")["cat"] + else: + open_store = lambda p: h5py.File(p, mode="r")["cat"] + + store1 = open_store(path) + store2 = open_store(path) + dtype1 = read_elem_lazy(store1).dtype + dtype2 = read_elem_lazy(store2).dtype + + assert isinstance(dtype1, LazyCategoricalDtype) + assert isinstance(dtype2, LazyCategoricalDtype) + assert dtype1._categories_elem is not dtype2._categories_elem + + # Same-location comparison should NOT call read_elem + with patch("anndata.io.read_elem", side_effect=AssertionError("read_elem called")): + assert dtype1 == dtype2 + + # Positive control: comparison with regular CategoricalDtype DOES call read_elem + with ( + pytest.raises(AssertionError, match="read_elem called"), + patch("anndata.io.read_elem", side_effect=AssertionError("read_elem called")), + ): + dtype1 == pd.CategoricalDtype(categories=["a", "b", "c"]) # noqa: B015 + + if backend == "h5ad": + store1.file.close() + store2.file.close() + + +def test_lazy_categorical_roundtrip_via_anndata(tmp_path: Path): + """Integration test: lazy categorical through full AnnData workflow. + + This test uses the full AnnData read/write path rather than write_elem/read_elem_lazy + to verify end-to-end integration including dtype caching and equality. + """ + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + categories = ["type_a", "type_b", "type_c"] + adata = AnnData( + X=np.zeros((6, 2)), + obs=pd.DataFrame({ + "cat": pd.Categorical(categories * 2), + "ordered_cat": pd.Categorical( + ["low", "high"] * 3, + categories=["low", "high"], + ordered=True, + ), + }), + ) + + path = tmp_path / "test.zarr" + adata.write_zarr(path) + + # Read lazy and verify dtype + lazy = read_lazy(path) + dtype1 = lazy.obs["cat"].dtype + dtype2 = lazy.obs["cat"].dtype # Same underlying array + + assert isinstance(dtype1, LazyCategoricalDtype) + assert dtype1 is dtype2 # Same instance (cached) + assert dtype1 == dtype2 + + # Verify ordered categorical + ordered_dtype = lazy.obs["ordered_cat"].dtype + assert isinstance(ordered_dtype, LazyCategoricalDtype) + assert ordered_dtype.ordered is True + + # Round-trip: lazy -> memory should equal original + loaded = lazy.to_memory() + assert loaded.obs["cat"].equals(adata.obs["cat"]) + assert loaded.obs["ordered_cat"].equals(adata.obs["ordered_cat"]) + + +def test_lazy_categorical_dtype_hash(cat_n3_store): + """Test LazyCategoricalDtype is hashable.""" + from anndata.experimental.backed._lazy_arrays import LazyCategoricalDtype + + lazy_cat = read_elem_lazy(cat_n3_store) + dtype = lazy_cat.dtype + assert isinstance(dtype, LazyCategoricalDtype) + + # Should be hashable (useful for collecting unique dtypes in sets/dicts) + h = hash(dtype) + assert isinstance(h, int) + + # Can be used in a set + s = {dtype} + assert dtype in s + + def test_nullable_string_index_decoding(tmp_path: Path): """Test that nullable string indices are properly decoded from bytes.