Skip to content

Commit 2f2edda

Browse files
authored
feat: AnnData.unwriteable based on AnnData._reduce + iter_outer + refactorings of other relevant functions (#2372)
1 parent 40e6ab1 commit 2f2edda

10 files changed

Lines changed: 313 additions & 89 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Writing a complete {class}`AnnData` object to disk in anndata’s native formats
9292
9393
AnnData.write_h5ad
9494
AnnData.write_zarr
95+
AnnData.unwriteable
9596
9697
9798
..

docs/concatenation.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Let's start off with an example:
2626
AnnData object with n_obs × n_vars = 700 × 765
2727
obs: 'bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score', 'G2M_score', 'phase', 'louvain'
2828
var: 'n_counts', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
29-
uns: 'bulk_labels_colors', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
3029
obsm: 'X_pca', 'X_umap'
3130
varm: 'PCs'
3231
obsp: ...
@@ -165,9 +164,9 @@ First, our example case:
165164
>>> blobs
166165
AnnData object with n_obs × n_vars = 640 × 30
167166
obs: 'blobs'
168-
uns: 'pca'
169167
obsm: 'X_pca'
170168
varm: 'PCs'
169+
uns: 'pca'
171170

172171
Now we will split this object by the categorical `"blobs"` and recombine it to illustrate different merge strategies.
173172

@@ -181,9 +180,9 @@ Now we will split this object by the categorical `"blobs"` and recombine it to i
181180
>>> adatas[0]
182181
AnnData object with n_obs × n_vars = 128 × 30
183182
obs: 'blobs'
184-
uns: 'pca'
185183
obsm: 'X_pca', 'qc'
186184
varm: 'PCs', '0_qc'
185+
uns: 'pca'
187186

188187
`adatas` is now a list of datasets with disjoint sets of observations and a common set of variables.
189188
Each object has had QC metrics computed, with observation-wise metrics stored under `"qc"` in `.obsm`, and variable-wise metrics stored with a unique key for each subset.

docs/release-notes/2372.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
New {meth}`anndata.AnnData.unwriteable` for checking if an `AnnData` can be written {user}`ilan-gold`

src/anndata/_core/anndata.py

Lines changed: 143 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from collections import OrderedDict
7+
from collections import OrderedDict, defaultdict
88
from collections.abc import Mapping, MutableMapping, Sequence
99
from copy import copy, deepcopy
1010
from functools import singledispatchmethod
@@ -26,8 +26,10 @@
2626
from .. import utils
2727
from .._settings import settings
2828
from ..compat import (
29+
AwkArray,
2930
DaskArray,
3031
IndexManager,
32+
XDataset,
3133
ZarrArray,
3234
_move_adj_mtx,
3335
has_xp,
@@ -39,6 +41,7 @@
3941
axis_len,
4042
deprecation_msg,
4143
ensure_df_homogeneous,
44+
iter_outer,
4245
raise_value_error_if_multiindex_columns,
4346
set_module,
4447
warn,
@@ -62,9 +65,12 @@
6265
from scipy import sparse
6366
from zarr.storage import StoreLike
6467

68+
from anndata.typing import RWAble
69+
70+
from .._types import ReduceFunc
6571
from ..acc import AdRef, Array, MapAcc, RefAcc
66-
from ..compat import XDataset
67-
from ..typing import Index, Index1D, _Index1DNorm, _XDataType
72+
from ..compat import CSArray, CSMatrix
73+
from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType
6874
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
6975

7076

@@ -512,53 +518,54 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915
512518
def __sizeof__(
513519
self, *, show_stratified: bool = False, with_disk: bool = False
514520
) -> int:
515-
def get_size(X) -> int:
516-
def cs_to_bytes(X) -> int:
517-
return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes)
521+
def cs_to_bytes(X: CSArray | CSMatrix) -> int:
522+
return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes)
518523

524+
def get_size(X: RWAble) -> int:
519525
if isinstance(X, h5py.Dataset) and with_disk:
520526
return int(np.array(X.shape).prod() * X.dtype.itemsize)
521527
elif isinstance(X, BaseCompressedSparseDataset) and with_disk:
522528
return cs_to_bytes(X._to_backed())
523529
elif issparse(X):
524530
return cs_to_bytes(X)
531+
elif isinstance(X, dict | MutableMapping):
532+
return sum(get_size(v) for v in X.values())
525533
else:
526534
return X.__sizeof__()
527535

528-
sizes = {}
529-
attrs = ["X", "_obs", "_var"]
530-
attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"]
531-
for attr in attrs + attrs_multi:
532-
if attr in attrs_multi:
533-
keys = getattr(self, attr).keys()
534-
s = sum(get_size(getattr(self, attr)[k]) for k in keys)
536+
def fold_size(
537+
elem: _XDataType | AxisStorable | pd.DataFrame | XDataset,
538+
*,
539+
accumulate: dict[str, int],
540+
attr_name: str | None, # TODO: type
541+
):
542+
if elem is None:
543+
size = 0
544+
elif elem is self.raw:
545+
size = (
546+
get_size(elem.X)
547+
+ get_size(elem.var)
548+
+ sum(get_size(v) for v in elem.varm.values())
549+
)
535550
else:
536-
s = get_size(getattr(self, attr))
537-
if s > 0 and show_stratified:
551+
size = get_size(elem)
552+
accumulate[attr_name] = size
553+
if size > 0 and show_stratified:
538554
from tqdm import tqdm
539555

540-
print(
541-
f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}"
542-
)
543-
sizes[attr] = s
544-
return sum(sizes.values())
556+
print(f"Size of {attr_name}: {tqdm.format_sizeof(size, 'B')}")
557+
return accumulate
558+
559+
return sum(self._reduce(fold_size, init=defaultdict(int)).values())
545560

546561
def _gen_repr(self, n_obs, n_vars) -> str:
547562
backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else ""
548563
descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}"
549-
for attr in [
550-
"obs",
551-
"var",
552-
"uns",
553-
"obsm",
554-
"varm",
555-
"layers",
556-
"obsp",
557-
"varp",
558-
]:
559-
keys = getattr(self, attr).keys()
560-
if len(keys) > 0:
561-
descr += f"\n {attr}: {str(list(keys))[1:-1]}"
564+
for attr_name, elem in iter_outer(self):
565+
if attr_name not in {"raw", "X"}:
566+
keys = elem.keys()
567+
if len(keys) > 0:
568+
descr += f"\n {attr_name}: {str(list(keys))[1:-1]}"
562569
return descr
563570

564571
def __repr__(self) -> str:
@@ -1383,27 +1390,16 @@ def to_memory(self, *, copy: bool = False) -> AnnData:
13831390
mem = backed[backed.obs["cluster"] == "a", :].to_memory()
13841391
"""
13851392
new = {}
1386-
for attr_name in [
1387-
"X",
1388-
"obs",
1389-
"var",
1390-
"obsm",
1391-
"varm",
1392-
"obsp",
1393-
"varp",
1394-
"layers",
1395-
"uns",
1396-
]:
1397-
attr = getattr(self, attr_name, None)
1393+
for attr_name, attr in iter_outer(self):
13981394
if attr is not None:
1399-
new[attr_name] = to_memory(attr, copy=copy)
1400-
1401-
if self.raw is not None:
1402-
new["raw"] = {
1403-
"X": to_memory(self.raw.X, copy=copy),
1404-
"var": to_memory(self.raw.var, copy=copy),
1405-
"varm": to_memory(self.raw.varm, copy=copy),
1406-
}
1395+
if attr is self.raw:
1396+
new["raw"] = {
1397+
"X": to_memory(self.raw.X, copy=copy),
1398+
"var": to_memory(self.raw.var, copy=copy),
1399+
"varm": to_memory(self.raw.varm, copy=copy),
1400+
}
1401+
else:
1402+
new[attr_name] = to_memory(attr, copy=copy)
14071403

14081404
if self.isbacked:
14091405
self.file.close()
@@ -1436,6 +1432,100 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData:
14361432
write_h5ad(filename, self)
14371433
return read_h5ad(filename, backed=mode)
14381434

1435+
def _reduce[T](
1436+
self,
1437+
func: ReduceFunc[T],
1438+
*,
1439+
init: T,
1440+
) -> T:
1441+
"""Accumulate a value starting from init by iterating over the parent "elems" of the AnnData object, i.e., raw, obs, varp, etc.
1442+
1443+
Parameters
1444+
----------
1445+
func
1446+
The function that performs the accumulation.
1447+
init
1448+
The starting value
1449+
1450+
Returns
1451+
-------
1452+
An accumulated value
1453+
"""
1454+
accumulate = init
1455+
for attr_name, attr in iter_outer(self):
1456+
accumulate = func(attr, accumulate=accumulate, attr_name=attr_name)
1457+
return accumulate
1458+
1459+
def unwriteable(self, *, store_type: Literal["h5", "zarr"] | None) -> bool:
1460+
"""Whether or not an `AnnData` object can be written to disk for a given store type.
1461+
1462+
Parameters
1463+
----------
1464+
store_type
1465+
Which backing store — `None` indicates that it can be written to either.
1466+
1467+
Returns
1468+
-------
1469+
Whether or not this object is writeable.
1470+
While the return type may change to include richer output about which elements cannot be written,
1471+
this new type's evaluation as a boolean will not change from the current behavior i.e.,
1472+
`bool(adata.unwriteable())` will always evaluate the same.
1473+
"""
1474+
1475+
from anndata._io.specs.registry import _REGISTRY
1476+
1477+
writeable_elems = {
1478+
src_type
1479+
for (dest_type, src_type, __) in _REGISTRY.write
1480+
if store_type is None or store_type in dest_type.__module__
1481+
}
1482+
1483+
def predicate( # noqa: PLR0911
1484+
elem: RWAble,
1485+
*,
1486+
accumulate: bool,
1487+
attr_name: str | None = None, # TODO: type
1488+
):
1489+
if elem is None:
1490+
return accumulate
1491+
if isinstance(elem, AnnData):
1492+
return accumulate and elem.unwriteable(store_type=store_type)
1493+
if isinstance(elem, pd.Categorical):
1494+
return accumulate and predicate(elem.categories, accumulate=accumulate)
1495+
if isinstance(elem, pd.Series | pd.Index):
1496+
# matches behavior in methods.py
1497+
return accumulate and predicate(elem._values, accumulate=accumulate)
1498+
if isinstance(elem, AwkArray):
1499+
import awkward as ak
1500+
1501+
container = ak.to_buffers(ak.to_packed(elem))
1502+
return accumulate and all(
1503+
predicate(v, accumulate=accumulate) for v in container[2].values()
1504+
)
1505+
if attr_name == "raw":
1506+
accumulate = accumulate and type(elem.X) in writeable_elems
1507+
return accumulate and all(
1508+
predicate(e[attr], accumulate=accumulate)
1509+
for e in [elem.var, elem.varm]
1510+
for attr in e
1511+
)
1512+
if attr_name in {
1513+
"obs",
1514+
"obsm",
1515+
"varm",
1516+
"var",
1517+
"layers",
1518+
"varp",
1519+
"obsp",
1520+
"uns",
1521+
} or isinstance(elem, pd.DataFrame | XDataset | MutableMapping):
1522+
return accumulate and all(
1523+
predicate(elem[k], accumulate=accumulate) for k in elem
1524+
)
1525+
return accumulate and type(elem) in writeable_elems
1526+
1527+
return self._reduce(predicate, init=True)
1528+
14391529
def var_names_make_unique(self, join: str = "-") -> None:
14401530
# Important to go through the setter so obsm dataframes are updated too
14411531
self.var_names = utils.make_index_unique(self.var.index, join)

src/anndata/_io/h5ad.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import re
4+
from collections.abc import MutableMapping
45
from functools import partial
56
from pathlib import Path
67
from types import MappingProxyType
@@ -23,7 +24,7 @@
2324
_from_fixed_length_strings,
2425
)
2526
from ..experimental import read_dispatched
26-
from ..utils import warn
27+
from ..utils import iter_outer, warn
2728
from .specs import read_elem, write_elem
2829
from .specs.registry import IOSpec, write_spec
2930
from .utils import (
@@ -84,23 +85,26 @@ def write_h5ad(
8485
f = cast("h5py.Group", f["/"])
8586
f.attrs.setdefault("encoding-type", "anndata")
8687
f.attrs.setdefault("encoding-version", "0.1.0")
87-
88-
_write_x(
89-
f,
90-
adata, # accessing adata.X reopens adata.file if it’s backed
91-
is_backed=adata.isbacked and adata.filename == filepath,
92-
as_dense=as_dense,
93-
dataset_kwargs=dataset_kwargs,
94-
)
95-
_write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs)
96-
write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs)
97-
write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs)
98-
write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs)
99-
write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs)
100-
write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs)
101-
write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs)
102-
write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs)
103-
write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs)
88+
for k, elem in iter_outer(adata):
89+
if k == "X":
90+
_write_x(
91+
f,
92+
adata, # accessing adata.X reopens adata.file if it’s backed
93+
is_backed=adata.isbacked and adata.filename == filepath,
94+
as_dense=as_dense,
95+
dataset_kwargs=dataset_kwargs,
96+
)
97+
elif k == "raw":
98+
_write_raw(
99+
f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs
100+
)
101+
else:
102+
write_elem(
103+
f,
104+
k,
105+
dict(elem) if isinstance(elem, MutableMapping) else elem,
106+
dataset_kwargs=dataset_kwargs,
107+
)
104108

105109

106110
def _write_x(

0 commit comments

Comments
 (0)