Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4b5d7ab
implement LazyCategoricalDtype
katosh Jan 8, 2026
7edb510
fix: LazyCategoricalDtype.__eq__ handle string comparison
katosh Jan 8, 2026
90ac52e
increase testing coverage of LazyCategoricalDtype
katosh Jan 8, 2026
667c823
Merge remote-tracking branch 'origin/main' into feat/lazy-categorical…
katosh Jan 8, 2026
c6a68da
manipulate cache for better testing
katosh Jan 8, 2026
92b7bfa
refactor(LazyCategoricalDtype): implement review suggestions
katosh Jan 8, 2026
03fe0b0
remove unnecessary docstring
katosh Jan 8, 2026
b57bdab
remove remaining docstring examples
katosh Jan 8, 2026
cc06639
refactor(LazyCategoricalDtype): address second round of review feedback
katosh Jan 9, 2026
9ff164c
test: refactor LazyCategoricalDtype tests to use write_elem/read_elem…
katosh Jan 9, 2026
d1c4d46
address third round review: simplify __eq__, improve repr, refactor f…
katosh Jan 12, 2026
3d8bbea
fix linting and simplify __eq__ using zarr/h5py built-in location equ…
katosh Jan 12, 2026
e8ee005
test: add same-location equality check for LazyCategoricalDtype
katosh Jan 12, 2026
e57ffb0
test: improve LazyCategoricalDtype equality test to verify no I/O
katosh Jan 13, 2026
f4b6cd9
test: parametrize all LazyCategoricalDtype tests for both zarr and h5ad
katosh Jan 13, 2026
175aef6
test: consolidate LazyCategoricalDtype tests and add no-load verifica…
katosh Jan 13, 2026
6259d14
test: fix misleading comment about hash requirement
katosh Jan 13, 2026
87d399b
simplify LazyCategoricalDtype comparison
katosh Jan 23, 2026
f740784
increase number of preview categories in LazyCategoricalDtype
katosh Jan 23, 2026
ac1cab5
test: consolidate categorical fixtures with factory pattern
katosh Jan 23, 2026
edb04fc
test: improve equality_no_load test with read_elem patching
katosh Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 195 additions & 13 deletions src/anndata/experimental/backed/_lazy_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
ZarrArray,
)

# Number of categories to show at head/tail in LazyCategoricalDtype repr
_N_CATEGORIES_REPR_SHOW = 10

if TYPE_CHECKING:
from pathlib import Path
from typing import Literal

from pandas._libs.missing import NAType
from pandas.core.dtypes.base import ExtensionDtype

from anndata.compat import ZarrGroup
from anndata.compat import H5Group, ZarrGroup

from ...compat import Index1DNorm

Expand All @@ -36,6 +39,183 @@
from xarray.core.indexing import ExplicitIndexer


class LazyCategoricalDtype(pd.CategoricalDtype):
"""A CategoricalDtype that lazily loads categories from zarr/h5 storage.

This dtype provides efficient access to categorical metadata without loading
all categories into memory via :meth:`head_categories`, :meth:`tail_categories`,
and :attr:`n_categories`. Accessing :attr:`categories` will load all categories
into memory.

Parameters
----------
categories_elem
The underlying zarr or h5 array (or group for nullable-string-array
encoding) containing category values.
ordered
Whether the categorical is ordered.
"""

# Attributes that should be preserved during copying/pickling
_metadata = ("_categories_elem", "_ordered_flag")

def __new__(
cls,
categories_elem: ZarrArray | H5Array | ZarrGroup | H5Group,
*,
ordered: bool = False,
):
# Create instance without calling parent __init__ with categories
instance = object.__new__(cls)
return instance

def __init__(
self,
categories_elem: ZarrArray | H5Array | ZarrGroup | H5Group,
*,
ordered: bool = False,
):
self._categories_elem = categories_elem
self._ordered_flag = bool(ordered)

def _get_categories_array(self) -> ZarrArray | H5Array:
"""Get the underlying categories array.

For string-array encoding: _categories_elem is directly the array.
For nullable-string-array encoding: _categories_elem would be a Group
with "values" key (not currently used for categories in anndata, but
handled defensively).
"""
if isinstance(self._categories_elem, (ZarrArray, H5Array)):
return self._categories_elem
# nullable-string-array encoding: Group with "values" and "mask"
return self._categories_elem["values"]

@cached_property
def categories(self) -> pd.Index:
"""Categories index. Loads all categories on first access and caches."""
from anndata.io import read_elem

return pd.Index(read_elem(self._categories_elem))

@property
def ordered(self) -> bool:
"""Whether the categorical is ordered."""
return self._ordered_flag

@property
def n_categories(self) -> int:
"""Number of categories (cheap, metadata only)."""
if "categories" in self.__dict__:
return len(self.categories)
return self._get_categories_array().shape[0]

def _get_categories_slice(
self, n: int, *, from_end: bool = False
) -> np.ndarray | pd.api.extensions.ExtensionArray:
"""Get n categories from start or end.

Parameters
----------
n
Number of categories to return.
from_end
If True, return last n categories. If False, return first n.

Returns
-------
np.ndarray or ExtensionArray
The requested categories.
"""
# If already fully loaded, slice from cache
if "categories" in self.__dict__:
sliced = self.categories[-n:] if from_end else self.categories[:n]
return np.asarray(sliced)

# Read partial from disk
from anndata._io.specs.registry import read_elem_partial

arr = self._get_categories_array()
total = arr.shape[0]
if from_end:
start, stop = max(total - n, 0), total
else:
start, stop = 0, min(n, total)
return read_elem_partial(arr, indices=slice(start, stop))

def head_categories(
self, n: int = 5
) -> np.ndarray | pd.api.extensions.ExtensionArray:
"""Return first n categories without loading all into memory.

Parameters
----------
n
Number of categories to return. Default 5.

Returns
-------
np.ndarray or ExtensionArray
The first n categories.
"""
return self._get_categories_slice(n, from_end=False)

def tail_categories(
self, n: int = 5
) -> np.ndarray | pd.api.extensions.ExtensionArray:
"""Return last n categories without loading all into memory.

Parameters
----------
n
Number of categories to return. Default 5.

Returns
-------
np.ndarray or ExtensionArray
The last n categories.
"""
return self._get_categories_slice(n, from_end=True)

def __repr__(self) -> str:
n_total = self.n_categories
ordered_str = ", ordered=True" if self.ordered else ""

if n_total <= _N_CATEGORIES_REPR_SHOW * 2:
# Small enough to show all categories
if "categories" in self.__dict__:
cats = list(self.categories)
else:
cats = list(self.head_categories(n_total))
return f"LazyCategoricalDtype(categories={cats!r}{ordered_str})"

# Show truncated: first n ... last n
head = list(self.head_categories(_N_CATEGORIES_REPR_SHOW))
tail = list(self.tail_categories(_N_CATEGORIES_REPR_SHOW))
cats_display = [*head, "...", *tail]
return f"LazyCategoricalDtype(categories={cats_display!r}, n={n_total}{ordered_str})"

def __hash__(self) -> int:
"""Hash based on identity of underlying array and ordered flag.

Required for use in sets and as dictionary keys (e.g., collecting
unique dtypes across AnnData objects).
"""
return hash((id(self._categories_elem), self._ordered_flag))

def __eq__(self, other) -> bool:
if isinstance(other, LazyCategoricalDtype):
has_same_ordering = self._ordered_flag == other._ordered_flag
are_arrays_equal = (self._categories_elem is other._categories_elem) or (
self._get_categories_array() == other._get_categories_array()
)
return has_same_ordering and are_arrays_equal
# Defer to pandas base implementation for all other comparisons
# This handles string comparison ("category"), CategoricalDtype comparisons,
# and all edge cases (None categories, ordered vs unordered, etc.)
return super().__eq__(other)


class ZarrOrHDF5Wrapper[K: (H5Array | H5AsTypeView, ZarrArray)](XZarrArrayWrapper):
def __init__(self, array: K) -> None:
# AstypeView from h5py .astype() lacks chunks attribute
Expand Down Expand Up @@ -87,7 +267,7 @@ class CategoricalArray[K: (H5Array, ZarrArray)](XBackendArray):
"""

_codes: ZarrOrHDF5Wrapper[K]
_categories: K
_categories_elem: K
shape: tuple[int, ...]
base_path_or_zarr_group: Path | ZarrGroup
elem_name: str
Expand All @@ -102,21 +282,22 @@ def __init__(
ordered: bool,
**kwargs,
):
self._categories = categories
self._categories_elem = categories
self._ordered = ordered
self._codes = ZarrOrHDF5Wrapper(codes)
self.shape = self._codes.shape
self.base_path_or_zarr_group = base_path_or_zarr_group
self.file_format = "zarr" if isinstance(codes, ZarrArray) else "h5"
self.elem_name = elem_name
# Create the lazy dtype - this is where categories are cached
self._lazy_dtype = LazyCategoricalDtype(
categories_elem=categories, ordered=ordered
)

@cached_property
def categories(self) -> np.ndarray:
if isinstance(self._categories, ZarrArray):
return self._categories[...]
from anndata.io import read_elem

return read_elem(self._categories)
@property
def categories(self) -> pd.Index | None:
"""All categories. Loads and caches on first access."""
return self._lazy_dtype.categories

def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray:
from xarray.core.extension_array import PandasExtensionArray
Expand All @@ -129,9 +310,10 @@ def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray:
categorical_array = categorical_array.remove_unused_categories()
return PandasExtensionArray(categorical_array)

@cached_property
def dtype(self):
return pd.CategoricalDtype(categories=self.categories, ordered=self._ordered)
@property
def dtype(self) -> LazyCategoricalDtype:
"""The dtype with lazy category loading support."""
return self._lazy_dtype


# circumvent https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580
Expand Down
Loading
Loading