Skip to content

Commit 14679ab

Browse files
aggregate - out of core (#395)
* add csc kernel * add F-continous * add dask test * update aggr * update kernels * fix sums * fix dimensions * update dtype * Update aggregate2 (#404) * test map_blocks * add dense prototype * update kernels * make nicer * fix 64_bit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * update blockwise * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch to map_blocks * add test and split * fix tests and mask * adds release note * update dtypes * merge dask * slim down * fix hvg --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 89a1294 commit 14679ab

7 files changed

Lines changed: 364 additions & 74 deletions

File tree

docs/release-notes/0.13.0.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
### 0.13.0 {small}`the-future`
2+
3+
```{rubric} Features
4+
```
5+
* Add support for aggregate operations on CSC matrices, Fortran-ordered arrays, and Dask with sparse CSR and dense matrices {pr}`395` {smaller}`S Dicks`
6+
7+
```{rubric} Performance
8+
```
9+
10+
```{rubric} Bug fixes
11+
```
12+
13+
14+
```{rubric} Misc
15+
```

docs/release-notes/index.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
# Release notes
44

5+
## Version 0.13.0
6+
```{include} /release-notes/0.13.0.md
7+
```
8+
59
## Version 0.12.0
610
```{include} /release-notes/0.12.7.md
711
```

src/rapids_singlecell/get/_aggregated.py

Lines changed: 170 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@
88
)
99

1010
import cupy as cp
11-
import numpy as np
1211
from anndata import AnnData
1312
from cupyx.scipy import sparse as cp_sparse
1413
from scanpy._utils import _resolve_axis
1514
from scanpy.get._aggregated import _combine_categories
1615

16+
from rapids_singlecell._compat import (
17+
DaskArray,
18+
_meta_dense,
19+
)
1720
from rapids_singlecell.get import _check_mask
1821
from rapids_singlecell.preprocessing._utils import _check_gpu_X
1922

2023
if TYPE_CHECKING:
2124
from collections.abc import Collection, Iterable
2225

26+
import numpy as np
2327
import pandas as pd
2428
from numpy.typing import NDArray
2529

@@ -52,7 +56,7 @@ def __init__(
5256
) -> None:
5357
self.mask = mask
5458
self.groupby = cp.array(groupby.codes, dtype=cp.int32)
55-
self.n_cells = cp.array(np.bincount(groupby.codes), dtype=cp.float64).reshape(
59+
self.n_cells = cp.array(cp.bincount(self.groupby), dtype=cp.float64).reshape(
5660
-1, 1
5761
)
5862
self.data = data
@@ -66,24 +70,131 @@ def _get_mask(self):
6670
else:
6771
return cp.ones(self.data.shape[0], dtype=bool)
6872

73+
def count_mean_var_dask(self, dof: int = 1, split_every: int = 2):
74+
"""
75+
This function is used to calculate the sum, mean, and variance of the data matrix.
76+
It automatically detects sparse vs dense matrices and uses the appropriate
77+
CUDA kernel for aggregation.
78+
"""
79+
import dask.array as da
80+
81+
assert dof >= 0
82+
from ._kernels._aggr_kernels import (
83+
_get_aggr_dense_kernel_C,
84+
_get_aggr_sparse_kernel,
85+
)
86+
87+
if isinstance(self.data._meta, cp.ndarray):
88+
kernel = _get_aggr_dense_kernel_C(self.data.dtype)
89+
is_sparse = False
90+
else:
91+
kernel = _get_aggr_sparse_kernel(self.data.dtype)
92+
is_sparse = True
93+
94+
kernel.compile()
95+
n_groups = self.n_cells.shape[0]
96+
97+
def __aggregate_dask(X_part, mask_part, groupby_part):
98+
out = cp.zeros((1, 3, n_groups, self.data.shape[1]), dtype=cp.float64)
99+
threads_per_block = 512
100+
101+
if is_sparse:
102+
# Sparse matrix kernel parameters
103+
grid = (X_part.shape[0],)
104+
kernel_args = (
105+
X_part.indptr,
106+
X_part.indices,
107+
X_part.data,
108+
)
109+
else:
110+
# Dense matrix kernel parameters
111+
N = X_part.shape[0] * X_part.shape[1]
112+
113+
blocks = min(
114+
(N + threads_per_block - 1) // threads_per_block,
115+
cp.cuda.Device().attributes["MultiProcessorCount"] * 8,
116+
)
117+
grid = (blocks,)
118+
kernel_args = (X_part,)
119+
120+
kernel(
121+
grid,
122+
(threads_per_block,),
123+
(
124+
*kernel_args,
125+
out,
126+
groupby_part,
127+
mask_part,
128+
X_part.shape[0],
129+
X_part.shape[1],
130+
n_groups,
131+
),
132+
)
133+
return out
134+
135+
# Prepare Dask arrays
136+
mask = self._get_mask()
137+
mask_dask = da.from_array(
138+
mask, chunks=(self.data.chunks[0]), meta=_meta_dense(mask.dtype)
139+
)
140+
groupby_dask = da.from_array(
141+
self.groupby,
142+
chunks=(self.data.chunks[0]),
143+
meta=_meta_dense(self.groupby.dtype),
144+
)
145+
146+
# Apply aggregation across all blocks
147+
out = da.map_blocks(
148+
__aggregate_dask,
149+
self.data,
150+
mask_dask[..., None],
151+
groupby_dask[..., None],
152+
meta=cp.empty([], dtype=cp.float64),
153+
dtype=cp.float64,
154+
new_axis=(1, 2),
155+
chunks=(
156+
(1,) * self.data.blocks.size,
157+
(3,),
158+
(n_groups,),
159+
(self.data.shape[1],),
160+
),
161+
)
162+
163+
# Compute final aggregated results
164+
out = out.sum(axis=0, split_every=split_every).compute()
165+
sums, counts, sq_sums = out[0], out[1], out[2]
166+
167+
# Calculate statistics
168+
counts = counts.astype(cp.int32)
169+
means = sums / self.n_cells
170+
var = sq_sums / self.n_cells - cp.power(means, 2)
171+
var *= self.n_cells / (self.n_cells - dof)
172+
173+
return {"mean": means, "var": var, "sum": sums, "count_nonzero": counts}
174+
69175
def count_mean_var_sparse(self, dof: int = 1):
70176
"""
71177
This function is used to calculate the sum, mean, and variance of the sparse data matrix.
72178
It uses a custom cuda-kernel to perform the aggregation.
73179
"""
74180

75181
assert dof >= 0
76-
from ._kernels._aggr_kernels import _get_aggr_sparse_kernel
182+
from ._kernels._aggr_kernels import (
183+
_get_aggr_sparse_kernel,
184+
_get_aggr_sparse_kernel_csc,
185+
)
186+
187+
out = cp.zeros(
188+
(3, self.n_cells.shape[0] * self.data.shape[1]), dtype=cp.float64
189+
)
77190

191+
block = (512,)
78192
if self.data.format == "csc":
79-
self.data = self.data.tocsr()
80-
means = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
81-
var = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
82-
sums = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
83-
counts = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.int32)
84-
block = (128,)
85-
grid = (self.data.shape[0],)
86-
aggr_kernel = _get_aggr_sparse_kernel(self.data.dtype)
193+
grid = (self.data.shape[1],)
194+
aggr_kernel = _get_aggr_sparse_kernel_csc(self.data.dtype)
195+
else:
196+
grid = (self.data.shape[0],)
197+
aggr_kernel = _get_aggr_sparse_kernel(self.data.dtype)
87198
mask = self._get_mask()
88199
aggr_kernel(
89200
grid,
@@ -92,23 +203,24 @@ def count_mean_var_sparse(self, dof: int = 1):
92203
self.data.indptr,
93204
self.data.indices,
94205
self.data.data,
95-
counts,
96-
sums,
97-
means,
98-
var,
206+
out,
99207
self.groupby,
100-
self.n_cells,
101208
mask,
102209
self.data.shape[0],
103210
self.data.shape[1],
211+
self.n_cells.shape[0],
104212
),
105213
)
106-
107-
var = var - cp.power(means, 2)
214+
sums, counts, sq_sums = out[0, :], out[1, :], out[2, :]
215+
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])
216+
sq_sums = sq_sums.reshape(self.n_cells.shape[0], self.data.shape[1])
217+
counts = counts.reshape(self.n_cells.shape[0], self.data.shape[1])
218+
counts = counts.astype(cp.int32)
219+
means = sums / self.n_cells
220+
var = sq_sums / self.n_cells - means**2
108221
var *= self.n_cells / (self.n_cells - dof)
109222

110223
results = {"sum": sums, "count_nonzero": counts, "mean": means, "var": var}
111-
112224
return results
113225

114226
def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
@@ -275,34 +387,44 @@ def count_mean_var_dense(self, dof: int = 1):
275387
"""
276388

277389
assert dof >= 0
278-
from ._kernels._aggr_kernels import _get_aggr_dense_kernel
390+
from ._kernels._aggr_kernels import (
391+
_get_aggr_dense_kernel_C,
392+
_get_aggr_dense_kernel_F,
393+
)
279394

280-
means = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
281-
var = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
282-
sums = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
283-
counts = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.int32)
284-
block = (128,)
285-
grid = (self.data.shape[0],)
286-
aggr_kernel = _get_aggr_dense_kernel(self.data.dtype)
395+
out = cp.zeros((3, self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
396+
397+
N = self.data.shape[0] * self.data.shape[1]
398+
threads_per_block = 512
399+
blocks = min(
400+
(N + threads_per_block - 1) // threads_per_block,
401+
cp.cuda.Device().attributes["MultiProcessorCount"] * 8,
402+
)
403+
if self.data.flags.c_contiguous:
404+
aggr_kernel = _get_aggr_dense_kernel_C(self.data.dtype)
405+
else:
406+
aggr_kernel = _get_aggr_dense_kernel_F(self.data.dtype)
287407
mask = self._get_mask()
288408
aggr_kernel(
289-
grid,
290-
block,
409+
(blocks,),
410+
(threads_per_block,),
291411
(
292-
self.data.data,
293-
counts,
294-
sums,
295-
means,
296-
var,
412+
self.data,
413+
out,
297414
self.groupby,
298-
self.n_cells,
299415
mask,
300416
self.data.shape[0],
301417
self.data.shape[1],
418+
self.n_cells.shape[0],
302419
),
303420
)
304-
305-
var = var - cp.power(means, 2)
421+
sums, counts, sq_sums = out[0], out[1], out[2]
422+
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])
423+
counts = counts.reshape(self.n_cells.shape[0], self.data.shape[1])
424+
sq_sums = sq_sums.reshape(self.n_cells.shape[0], self.data.shape[1])
425+
counts = counts.astype(cp.int32)
426+
means = sums / self.n_cells
427+
var = sq_sums / self.n_cells - cp.power(means, 2)
306428
var *= self.n_cells / (self.n_cells - dof)
307429

308430
results = {"sum": sums, "count_nonzero": counts, "mean": means, "var": var}
@@ -322,6 +444,7 @@ def aggregate(
322444
obsm: str | None = None,
323445
varm: str | None = None,
324446
return_sparse: bool = False,
447+
**kwargs,
325448
) -> AnnData:
326449
"""\
327450
Aggregate data matrix based on some categorical grouping.
@@ -416,11 +539,9 @@ def aggregate(
416539
elif axis == 1:
417540
# i.e., all of `varm`, `obsm`, `layers` are None so we use `X` which must be transposed
418541
data = data.T
419-
_check_gpu_X(data)
542+
_check_gpu_X(data, allow_dask=True)
420543
dim_df = getattr(adata, axis_name)
421544
categorical, new_label_df = _combine_categories(dim_df, by)
422-
# Actual computation
423-
424545
groupby = Aggregate(groupby=categorical, data=data, mask=mask)
425546

426547
funcs = set([func] if isinstance(func, str) else func)
@@ -429,6 +550,15 @@ def aggregate(
429550

430551
if isinstance(data, cp.ndarray):
431552
result = groupby.count_mean_var_dense(dof)
553+
elif isinstance(data, DaskArray):
554+
if "split_every" in kwargs:
555+
assert isinstance(kwargs["split_every"], int)
556+
assert kwargs["split_every"] > 0
557+
split_every = kwargs["split_every"]
558+
else:
559+
split_every = 2
560+
result = groupby.count_mean_var_dask(dof, split_every=split_every)
561+
432562
else:
433563
if return_sparse:
434564
result = groupby.count_mean_var_sparse_sparse(funcs, dof)

0 commit comments

Comments
 (0)