|
4 | 4 |
|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | | -from collections import OrderedDict |
| 7 | +from collections import OrderedDict, defaultdict |
8 | 8 | from collections.abc import Mapping, MutableMapping, Sequence |
9 | 9 | from copy import copy, deepcopy |
10 | 10 | from functools import singledispatchmethod |
|
26 | 26 | from .. import utils |
27 | 27 | from .._settings import settings |
28 | 28 | from ..compat import ( |
| 29 | + AwkArray, |
29 | 30 | DaskArray, |
30 | 31 | IndexManager, |
| 32 | + XDataset, |
31 | 33 | ZarrArray, |
32 | 34 | _move_adj_mtx, |
33 | 35 | has_xp, |
|
39 | 41 | axis_len, |
40 | 42 | deprecation_msg, |
41 | 43 | ensure_df_homogeneous, |
| 44 | + iter_outer, |
42 | 45 | raise_value_error_if_multiindex_columns, |
43 | 46 | set_module, |
44 | 47 | warn, |
|
62 | 65 | from scipy import sparse |
63 | 66 | from zarr.storage import StoreLike |
64 | 67 |
|
| 68 | + from anndata.typing import RWAble |
| 69 | + |
| 70 | + from .._types import ReduceFunc |
65 | 71 | from ..acc import AdRef, Array, MapAcc, RefAcc |
66 | | - from ..compat import XDataset |
67 | | - from ..typing import Index, Index1D, _Index1DNorm, _XDataType |
| 72 | + from ..compat import CSArray, CSMatrix |
| 73 | + from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType |
68 | 74 | from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView |
69 | 75 |
|
70 | 76 |
|
@@ -512,53 +518,54 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 |
512 | 518 | def __sizeof__( |
513 | 519 | self, *, show_stratified: bool = False, with_disk: bool = False |
514 | 520 | ) -> int: |
515 | | - def get_size(X) -> int: |
516 | | - def cs_to_bytes(X) -> int: |
517 | | - return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) |
| 521 | + def cs_to_bytes(X: CSArray | CSMatrix) -> int: |
| 522 | + return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) |
518 | 523 |
|
| 524 | + def get_size(X: RWAble) -> int: |
519 | 525 | if isinstance(X, h5py.Dataset) and with_disk: |
520 | 526 | return int(np.array(X.shape).prod() * X.dtype.itemsize) |
521 | 527 | elif isinstance(X, BaseCompressedSparseDataset) and with_disk: |
522 | 528 | return cs_to_bytes(X._to_backed()) |
523 | 529 | elif issparse(X): |
524 | 530 | return cs_to_bytes(X) |
| 531 | + elif isinstance(X, dict | MutableMapping): |
| 532 | + return sum(get_size(v) for v in X.values()) |
525 | 533 | else: |
526 | 534 | return X.__sizeof__() |
527 | 535 |
|
528 | | - sizes = {} |
529 | | - attrs = ["X", "_obs", "_var"] |
530 | | - attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"] |
531 | | - for attr in attrs + attrs_multi: |
532 | | - if attr in attrs_multi: |
533 | | - keys = getattr(self, attr).keys() |
534 | | - s = sum(get_size(getattr(self, attr)[k]) for k in keys) |
| 536 | + def fold_size( |
| 537 | + elem: _XDataType | AxisStorable | pd.DataFrame | XDataset, |
| 538 | + *, |
| 539 | + accumulate: dict[str, int], |
| 540 | + attr_name: str | None, # TODO: type |
| 541 | + ): |
| 542 | + if elem is None: |
| 543 | + size = 0 |
| 544 | + elif elem is self.raw: |
| 545 | + size = ( |
| 546 | + get_size(elem.X) |
| 547 | + + get_size(elem.var) |
| 548 | + + sum(get_size(v) for v in elem.varm.values()) |
| 549 | + ) |
535 | 550 | else: |
536 | | - s = get_size(getattr(self, attr)) |
537 | | - if s > 0 and show_stratified: |
| 551 | + size = get_size(elem) |
| 552 | + accumulate[attr_name] = size |
| 553 | + if size > 0 and show_stratified: |
538 | 554 | from tqdm import tqdm |
539 | 555 |
|
540 | | - print( |
541 | | - f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}" |
542 | | - ) |
543 | | - sizes[attr] = s |
544 | | - return sum(sizes.values()) |
| 556 | + print(f"Size of {attr_name}: {tqdm.format_sizeof(size, 'B')}") |
| 557 | + return accumulate |
| 558 | + |
| 559 | + return sum(self._reduce(fold_size, init=defaultdict(int)).values()) |
545 | 560 |
|
546 | 561 | def _gen_repr(self, n_obs, n_vars) -> str: |
547 | 562 | backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" |
548 | 563 | descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" |
549 | | - for attr in [ |
550 | | - "obs", |
551 | | - "var", |
552 | | - "uns", |
553 | | - "obsm", |
554 | | - "varm", |
555 | | - "layers", |
556 | | - "obsp", |
557 | | - "varp", |
558 | | - ]: |
559 | | - keys = getattr(self, attr).keys() |
560 | | - if len(keys) > 0: |
561 | | - descr += f"\n {attr}: {str(list(keys))[1:-1]}" |
| 564 | + for attr_name, elem in iter_outer(self): |
| 565 | + if attr_name not in {"raw", "X"}: |
| 566 | + keys = elem.keys() |
| 567 | + if len(keys) > 0: |
| 568 | + descr += f"\n {attr_name}: {str(list(keys))[1:-1]}" |
562 | 569 | return descr |
563 | 570 |
|
564 | 571 | def __repr__(self) -> str: |
@@ -1383,27 +1390,16 @@ def to_memory(self, *, copy: bool = False) -> AnnData: |
1383 | 1390 | mem = backed[backed.obs["cluster"] == "a", :].to_memory() |
1384 | 1391 | """ |
1385 | 1392 | new = {} |
1386 | | - for attr_name in [ |
1387 | | - "X", |
1388 | | - "obs", |
1389 | | - "var", |
1390 | | - "obsm", |
1391 | | - "varm", |
1392 | | - "obsp", |
1393 | | - "varp", |
1394 | | - "layers", |
1395 | | - "uns", |
1396 | | - ]: |
1397 | | - attr = getattr(self, attr_name, None) |
| 1393 | + for attr_name, attr in iter_outer(self): |
1398 | 1394 | if attr is not None: |
1399 | | - new[attr_name] = to_memory(attr, copy=copy) |
1400 | | - |
1401 | | - if self.raw is not None: |
1402 | | - new["raw"] = { |
1403 | | - "X": to_memory(self.raw.X, copy=copy), |
1404 | | - "var": to_memory(self.raw.var, copy=copy), |
1405 | | - "varm": to_memory(self.raw.varm, copy=copy), |
1406 | | - } |
| 1395 | + if attr is self.raw: |
| 1396 | + new["raw"] = { |
| 1397 | + "X": to_memory(self.raw.X, copy=copy), |
| 1398 | + "var": to_memory(self.raw.var, copy=copy), |
| 1399 | + "varm": to_memory(self.raw.varm, copy=copy), |
| 1400 | + } |
| 1401 | + else: |
| 1402 | + new[attr_name] = to_memory(attr, copy=copy) |
1407 | 1403 |
|
1408 | 1404 | if self.isbacked: |
1409 | 1405 | self.file.close() |
@@ -1436,6 +1432,100 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: |
1436 | 1432 | write_h5ad(filename, self) |
1437 | 1433 | return read_h5ad(filename, backed=mode) |
1438 | 1434 |
|
| 1435 | + def _reduce[T]( |
| 1436 | + self, |
| 1437 | + func: ReduceFunc[T], |
| 1438 | + *, |
| 1439 | + init: T, |
| 1440 | + ) -> T: |
| 1441 | + """Accumulate a value starting from `init` by iterating over the parent "elems" of the AnnData object, i.e., raw, obs, varp, etc.
| 1442 | +
|
| 1443 | + Parameters |
| 1444 | + ---------- |
| 1445 | + func |
| 1446 | + The function that performs the accumulation. |
| 1447 | + init |
| 1448 | + The starting value.
| 1449 | +
|
| 1450 | + Returns |
| 1451 | + ------- |
| 1452 | + An accumulated value |
| 1453 | + """ |
| 1454 | + accumulate = init |
| 1455 | + for attr_name, attr in iter_outer(self): |
| 1456 | + accumulate = func(attr, accumulate=accumulate, attr_name=attr_name) |
| 1457 | + return accumulate |
| 1458 | + |
| 1459 | + def unwriteable(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: |
| 1460 | + """Whether or not an `AnnData` object can be written to disk for a given store type. |
| 1461 | +
|
| 1462 | + Parameters |
| 1463 | + ---------- |
| 1464 | + store_type |
| 1465 | + Which backing store — `None` indicates that it can be written to either.
| 1466 | +
|
| 1467 | + Returns |
| 1468 | + ------- |
| 1469 | + Whether or not this object is writeable. |
| 1470 | + While the return type may change to include richer output about which elements cannot be written, |
| 1471 | + this new type's evaluation as a boolean will not change from the current behavior; i.e.,
| 1472 | + `bool(adata.unwriteable())` will always evaluate the same. |
| 1473 | + """ |
| 1474 | + |
| 1475 | + from anndata._io.specs.registry import _REGISTRY |
| 1476 | + |
| 1477 | + writeable_elems = { |
| 1478 | + src_type |
| 1479 | + for (dest_type, src_type, __) in _REGISTRY.write |
| 1480 | + if store_type is None or store_type in dest_type.__module__ |
| 1481 | + } |
| 1482 | + |
| 1483 | + def predicate( # noqa: PLR0911 |
| 1484 | + elem: RWAble, |
| 1485 | + *, |
| 1486 | + accumulate: bool, |
| 1487 | + attr_name: str | None = None, # TODO: type |
| 1488 | + ): |
| 1489 | + if elem is None: |
| 1490 | + return accumulate |
| 1491 | + if isinstance(elem, AnnData): |
| 1492 | + return accumulate and elem.unwriteable(store_type=store_type) |
| 1493 | + if isinstance(elem, pd.Categorical): |
| 1494 | + return accumulate and predicate(elem.categories, accumulate=accumulate) |
| 1495 | + if isinstance(elem, pd.Series | pd.Index): |
| 1496 | + # matches behavior in methods.py |
| 1497 | + return accumulate and predicate(elem._values, accumulate=accumulate) |
| 1498 | + if isinstance(elem, AwkArray): |
| 1499 | + import awkward as ak |
| 1500 | + |
| 1501 | + container = ak.to_buffers(ak.to_packed(elem)) |
| 1502 | + return accumulate and all( |
| 1503 | + predicate(v, accumulate=accumulate) for v in container[2].values() |
| 1504 | + ) |
| 1505 | + if attr_name == "raw": |
| 1506 | + accumulate = accumulate and type(elem.X) in writeable_elems |
| 1507 | + return accumulate and all( |
| 1508 | + predicate(e[attr], accumulate=accumulate) |
| 1509 | + for e in [elem.var, elem.varm] |
| 1510 | + for attr in e |
| 1511 | + ) |
| 1512 | + if attr_name in { |
| 1513 | + "obs", |
| 1514 | + "obsm", |
| 1515 | + "varm", |
| 1516 | + "var", |
| 1517 | + "layers", |
| 1518 | + "varp", |
| 1519 | + "obsp", |
| 1520 | + "uns", |
| 1521 | + } or isinstance(elem, pd.DataFrame | XDataset | MutableMapping): |
| 1522 | + return accumulate and all( |
| 1523 | + predicate(elem[k], accumulate=accumulate) for k in elem |
| 1524 | + ) |
| 1525 | + return accumulate and type(elem) in writeable_elems |
| 1526 | + |
| 1527 | + return self._reduce(predicate, init=True) |
| 1528 | + |
1439 | 1529 | def var_names_make_unique(self, join: str = "-") -> None: |
1440 | 1530 | # Important to go through the setter so obsm dataframes are updated too |
1441 | 1531 | self.var_names = utils.make_index_unique(self.var.index, join) |
|
0 commit comments