Skip to content

Commit cb3e6c2

Browse files
ilan-gold, pre-commit-ci[bot], flying-sheep, meeseeksmachine
authored
perf: numba based aggregations for sparse data (#4062)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A. <flying-sheep@web.de> Co-authored-by: Lumberbot (aka Jack) <39504233+meeseeksmachine@users.noreply.github.com>
1 parent 87dc1ec commit cb3e6c2

5 files changed

Lines changed: 194 additions & 37 deletions

File tree

benchmarks/benchmarks/_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def bmmc(n_obs: int = 400) -> AnnData:
103103

104104
@cache
105105
def _lung93k() -> AnnData:
106-
path = pooch.retrieve(
107-
url="https://figshare.com/ndownloader/files/45788454",
108-
known_hash="md5:4f28af5ff226052443e7e0b39f3f9212",
106+
registry = pooch.create(
107+
path=pooch.os_cache("pooch"),
108+
base_url="doi:10.6084/m9.figshare.25664775.v1/",
109109
)
110+
registry.load_registry_from_doi()
111+
path = registry.fetch("adata.raw_compressed.h5ad")
110112
adata = sc.read_h5ad(path)
111113
assert isinstance(adata.X, CSRBase)
112114
adata.layers["counts"] = adata.X.astype(np.int32, copy=True)

benchmarks/benchmarks/preprocessing_counts.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import anndata as ad
1313

1414
import scanpy as sc
15+
from scanpy._utils import get_literal_vals
16+
from scanpy.get._aggregated import AggType
1517

16-
from ._utils import get_count_dataset
18+
from ._utils import get_count_dataset, get_dataset
1719

1820
if TYPE_CHECKING:
1921
from typing import Any
@@ -146,3 +148,27 @@ def time_log1p(self, *_) -> None:
146148
def peakmem_log1p(self, *_) -> None:
147149
self.adata.uns.pop("log1p", None)
148150
sc.pp.log1p(self.adata)
151+
152+
153+
class Agg:  # noqa: D101
    # asv parametrization: one benchmark instance per aggregation name
    # declared by the AggType literal (e.g. sum/mean/var/count_nonzero).
    params: tuple[AggType] = tuple(get_literal_vals(AggType))
    param_names = ("agg_name",)

    def setup_cache(self) -> None:
        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
        adata, _ = get_dataset("lung93k")
        adata.write_h5ad("lung93k.h5ad")

    def setup(self, agg_name: AggType) -> None:
        # Read back the locally cached file written by setup_cache, avoiding
        # a re-download in every asv worker process.
        self.adata = ad.read_h5ad("lung93k.h5ad")
        self.agg_name = agg_name

    def time_agg(self, *_) -> None:
        # Wall-time benchmark: aggregate the "counts" layer per patient.
        sc.get.aggregate(
            self.adata, by="PatientNumber", func=self.agg_name, layer="counts"
        )

    def peakmem_agg(self, *_) -> None:
        # Peak-memory benchmark for the same aggregation call.
        sc.get.aggregate(
            self.adata, by="PatientNumber", func=self.agg_name, layer="counts"
        )

docs/release-notes/4062.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `numba` kernels for mean/var/count-nonzero/sum aggregation of sparse data in {func}`scanpy.get.aggregate` {smaller}`I Gold`

src/scanpy/get/_aggregated.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import numpy as np
77
import pandas as pd
8-
from anndata import AnnData, utils
8+
from anndata import AnnData
99
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
1010
from scipy import sparse
1111
from sklearn.utils.sparsefuncs import csc_median_axis_0
1212

1313
from scanpy._compat import CSBase, CSRBase, DaskArray
1414

1515
from .._utils import _resolve_axis, get_literal_vals
16+
from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr
1617
from .get import _check_mask
1718

1819
if TYPE_CHECKING:
@@ -25,7 +26,7 @@
2526
type AggType = ConstantDtypeAgg | Literal["mean", "var"]
2627

2728

28-
class Aggregate:
29+
class Aggregate[ArrayT: np.ndarray | CSBase]:
2930
"""Functionality for generic grouping and aggregating.
3031
3132
There is currently support for count_nonzero, sum, mean, and variance.
@@ -53,19 +54,22 @@ class Aggregate:
5354
def __init__(
5455
self,
5556
groupby: pd.Categorical,
56-
data: Array,
57+
data: ArrayT,
5758
*,
5859
mask: NDArray[np.bool] | None = None,
5960
) -> None:
6061
self.groupby = groupby
6162
if (missing := groupby.isna()).any():
6263
mask = mask & ~missing if mask is not None else ~missing
6364
self.indicator_matrix = sparse_indicator(groupby, mask=mask)
65+
if isinstance(data, CSBase):
66+
# TODO: Look into if this can be CSR and fast for dense
67+
self.indicator_matrix = self.indicator_matrix.tocsr()
6468
self.data = data
6569

6670
groupby: pd.Categorical
67-
indicator_matrix: sparse.coo_matrix
68-
data: Array
71+
indicator_matrix: CSRBase | sparse.coo_array
72+
data: ArrayT
6973

7074
def count_nonzero(self) -> NDArray[np.integer]:
7175
"""Count the number of observations in each group.
@@ -75,19 +79,30 @@ def count_nonzero(self) -> NDArray[np.integer]:
7579
Array of counts.
7680
7781
"""
78-
# pattern = self.data._with_data(np.broadcast_to(1, len(self.data.data)))
79-
# return self.indicator_matrix @ pattern
80-
return utils.asarray(self.indicator_matrix @ (self.data != 0))
82+
return self._sum(data=(self.data != 0).astype("uint8"))
83+
84+
def _sum(self, data: ArrayT):
85+
if isinstance(data, np.ndarray):
86+
res = self.indicator_matrix @ data
87+
if isinstance(res, CSBase):
88+
return res.toarray()
89+
return res
90+
dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64
91+
out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype)
92+
(agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)(
93+
self.indicator_matrix, data, out
94+
)
95+
return out
8196

82-
def sum(self) -> Array:
97+
def sum(self) -> np.ndarray:
8398
"""Compute the sum per feature per group of observations.
8499
85100
Returns
86101
-------
87102
Array of sum.
88103
89104
"""
90-
return utils.asarray(self.indicator_matrix @ self.data)
105+
return self._sum(self.data)
91106

92107
def mean(self) -> Array:
93108
"""Compute the mean per feature per group of observations.
@@ -97,10 +112,7 @@ def mean(self) -> Array:
97112
Array of mean.
98113
99114
"""
100-
return (
101-
utils.asarray(self.indicator_matrix @ self.data)
102-
/ np.bincount(self.groupby.codes)[:, None]
103-
)
115+
return self.sum() / np.bincount(self.groupby.codes)[:, None]
104116

105117
def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
106118
"""Compute the count, as well as mean and variance per feature, per group of observations.
@@ -124,14 +136,17 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
124136
assert dof >= 0
125137

126138
group_counts = np.bincount(self.groupby.codes)
127-
mean_ = self.mean()
128-
# sparse matrices do not support ** for elementwise power.
129-
mean_sq = (
130-
utils.asarray(self.indicator_matrix @ _power(self.data, 2))
131-
/ group_counts[:, None]
132-
)
133-
sq_mean = mean_**2
134-
var_ = mean_sq - sq_mean
139+
if isinstance(self.data, np.ndarray):
140+
mean_ = self.mean()
141+
# sparse matrices do not support ** for elementwise power.
142+
mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None]
143+
sq_mean = mean_**2
144+
var_ = mean_sq - sq_mean
145+
else:
146+
mean_, var_ = (
147+
mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc
148+
)(self.indicator_matrix, self.data)
149+
sq_mean = mean_**2
135150
# TODO: Why these values exactly? Because they are high relative to the datatype?
136151
# (unchanged from original code: https://github.com/scverse/anndata/pull/564)
137152
precision = 2 << (42 if self.data.dtype == np.float64 else 20)
@@ -550,18 +565,15 @@ def sparse_indicator(
550565
categorical: pd.Categorical,
551566
*,
552567
mask: NDArray[np.bool] | None = None,
553-
weight: NDArray[np.floating] | None = None,
554-
) -> sparse.coo_matrix:
555-
if mask is not None and weight is None:
556-
weight = mask.astype(np.float32)
557-
elif mask is not None and weight is not None:
558-
weight = mask * weight
559-
elif mask is None and weight is None:
560-
weight = np.broadcast_to(1.0, len(categorical))
568+
) -> sparse.coo_array:
569+
# TODO: why is this float64. This is a scanpy 2.0 problem maybe?
570+
mask = (
571+
np.broadcast_to(1.0, len(categorical)) if mask is None else mask.astype("uint8")
572+
)
561573
# can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked
562-
codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0)
563-
a = sparse.coo_matrix(
564-
(weight, (codes, np.arange(len(categorical)))),
574+
codes = np.where(mask, categorical.codes, 0)
575+
a = sparse.coo_array(
576+
(mask, (codes, np.arange(len(categorical)))),
565577
shape=(len(categorical.categories), len(categorical)),
566578
)
567579
return a

src/scanpy/get/_kernels.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import numba
6+
import numpy as np
7+
from fast_array_utils.numba import njit
8+
9+
if TYPE_CHECKING:
10+
from numpy.typing import NDArray
11+
12+
from .._compat import CSCBase, CSRBase
13+
14+
15+
@njit
def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray) -> None:
    """Accumulate per-group column sums of CSR ``data`` into ``out``.

    ``indicator`` is a (groups x observations) CSR membership matrix; the rows
    of ``data`` belonging to each group are summed column-wise into
    ``out[group]``. ``out`` must be zero-initialized by the caller.
    """
    # Each group writes only to its own row of `out`, so the outer loop
    # parallelizes without races.
    for group in numba.prange(indicator.shape[0]):
        member_lo = indicator.indptr[group]
        member_hi = indicator.indptr[group + 1]
        for member in range(member_lo, member_hi):
            row = indicator.indices[member]
            # Walk the stored entries of this observation's row in `data`.
            for nz in range(data.indptr[row], data.indptr[row + 1]):
                out[group, data.indices[nz]] += data.data[nz]
29+
30+
31+
@njit
def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None:
    """Accumulate per-group column sums of CSC ``data`` into ``out``.

    ``indicator`` is a (groups x observations) CSR membership matrix.
    ``out`` must be zero-initialized by the caller.
    """
    # Invert the indicator: map every observation (row of `data`) to its
    # group; -1 marks observations that belong to no group (e.g. masked out).
    row_group = np.full(data.shape[0], -1, dtype=np.int64)
    for group in range(indicator.shape[0]):
        for k in range(indicator.indptr[group], indicator.indptr[group + 1]):
            row_group[indicator.indices[k]] = group

    # Columns are independent, so iterate CSC columns in parallel.
    for col in numba.prange(data.shape[1]):
        for nz in range(data.indptr[col], data.indptr[col + 1]):
            group = row_group[data.indices[nz]]
            if group != -1:
                out[group, col] += data.data[nz]
49+
50+
51+
@njit
def mean_var_csr(
    indicator: CSRBase,
    data: CSRBase,  # fixed: was annotated CSCBase, but this is the CSR kernel
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-group column-wise mean and biased variance of CSR ``data``.

    ``indicator`` is a (groups x observations) CSR membership matrix.
    Returns ``(mean, var)`` with shape ``(n_groups, n_cols)``; any dof
    correction is applied by the caller.
    """
    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    # Groups write to disjoint rows of mean/var, so parallelize over groups.
    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        for row_num in range(start_cat_idx, stop_cat_idx):
            obs_per_cat = indicator.indices[row_num]

            start_obs = data.indptr[obs_per_cat]
            end_obs = data.indptr[obs_per_cat + 1]

            for j in range(start_obs, end_obs):
                col = data.indices[j]
                # Cast up front so integer data cannot overflow in the
                # sum-of-squares accumulation. (Bug fix: a stray
                # `value = data.data[j]` used to overwrite this cast.)
                value = np.float64(data.data[j])
                mean[cat_num, col] += value
                var[cat_num, col] += value * value

        # Finalize this group: E[x^2] - E[x]^2. An empty group (n_obs == 0)
        # yields nan, matching the original behavior.
        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var
80+
81+
82+
@njit
def mean_var_csc(
    indicator: CSRBase, data: CSCBase
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-group column-wise mean and biased variance of CSC ``data``.

    ``indicator`` is a (groups x observations) CSR membership matrix.
    Returns ``(mean, var)`` with shape ``(n_groups, n_cols)``; any dof
    correction is applied by the caller.
    """
    # Map each observation (row of `data`) to its group; -1 = no group.
    obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64)

    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    for cat in range(indicator.shape[0]):
        for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_to_cat[indicator.indices[k]] = cat

    # Accumulate sums and sums of squares; columns are independent.
    for col in numba.prange(data.shape[1]):
        start = data.indptr[col]
        end = data.indptr[col + 1]

        for j in range(start, end):
            obs = data.indices[j]
            cat = obs_to_cat[obs]

            if cat != -1:
                # Cast up front so integer data cannot overflow in the
                # sum-of-squares accumulation. (Bug fix: a stray
                # `value = data.data[j]` used to overwrite this cast.)
                value = np.float64(data.data[j])
                mean[cat, col] += value
                var[cat, col] += value * value

    # Finalize each group: E[x^2] - E[x]^2. An empty group (n_obs == 0)
    # yields nan, matching the original behavior.
    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var

0 commit comments

Comments
 (0)