Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/release-notes/0.13.0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
### 0.13.0 {small}`the-future`

```{rubric} Features
```
* Add support for aggregate operations on CSC matrices, Fortran-ordered arrays, and Dask arrays backed by sparse CSR or dense chunks {pr}`395` {smaller}`S Dicks`

```{rubric} Performance
```

```{rubric} Bug fixes
```


```{rubric} Misc
```
4 changes: 4 additions & 0 deletions docs/release-notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

# Release notes

## Version 0.13.0
```{include} /release-notes/0.13.0.md
```

## Version 0.12.0
```{include} /release-notes/0.12.7.md
```
Expand Down
210 changes: 170 additions & 40 deletions src/rapids_singlecell/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
)

import cupy as cp
import numpy as np
from anndata import AnnData
from cupyx.scipy import sparse as cp_sparse
from scanpy._utils import _resolve_axis
from scanpy.get._aggregated import _combine_categories

from rapids_singlecell._compat import (
DaskArray,
_meta_dense,
)
from rapids_singlecell.get import _check_mask
from rapids_singlecell.preprocessing._utils import _check_gpu_X

if TYPE_CHECKING:
from collections.abc import Collection, Iterable

import numpy as np
import pandas as pd
from numpy.typing import NDArray

Expand Down Expand Up @@ -52,7 +56,7 @@ def __init__(
) -> None:
self.mask = mask
self.groupby = cp.array(groupby.codes, dtype=cp.int32)
self.n_cells = cp.array(np.bincount(groupby.codes), dtype=cp.float64).reshape(
self.n_cells = cp.array(cp.bincount(self.groupby), dtype=cp.float64).reshape(
-1, 1
)
self.data = data
Expand All @@ -66,24 +70,131 @@ def _get_mask(self):
else:
return cp.ones(self.data.shape[0], dtype=bool)

def count_mean_var_dask(self, dof: int = 1, split_every: int = 2):
    """
    Calculate per-group sum, non-zero count, mean, and variance of a Dask-backed matrix.

    Every Dask block is aggregated with a custom CUDA kernel. The dense or
    the sparse kernel is selected from the type of ``self.data._meta``. The
    per-block partial results are stacked along a leading axis and reduced
    with a Dask ``sum``.

    Parameters
    ----------
    dof
        Degrees of freedom subtracted from the group size when rescaling
        the variance. Must be non-negative.
    split_every
        Forwarded to the Dask ``sum`` that combines the per-block results.

    Returns
    -------
    dict
        Keys ``"mean"``, ``"var"``, ``"sum"`` and ``"count_nonzero"``; each
        value has shape ``(n_groups, self.data.shape[1])``.

    Notes
    -----
    NOTE(review): the per-block output is allocated with the full
    ``self.data.shape[1]`` columns, so this presumably expects the array to
    be chunked along rows only -- confirm against the callers.
    """
    import dask.array as da

    assert dof >= 0
    from ._kernels._aggr_kernels import (
        _get_aggr_dense_kernel_C,
        _get_aggr_sparse_kernel,
    )

    is_sparse = not isinstance(self.data._meta, cp.ndarray)
    if is_sparse:
        kernel = _get_aggr_sparse_kernel(self.data.dtype)
    else:
        kernel = _get_aggr_dense_kernel_C(self.data.dtype)

    # Compile once up front so the per-block calls reuse the compiled kernel.
    kernel.compile()
    n_groups = self.n_cells.shape[0]
    n_features = self.data.shape[1]
    threads_per_block = 512

    def _aggregate_block(X_part, mask_part, groupby_part):
        # Layout: (block, statistic, group, feature); the leading axis of
        # length 1 is what the final reduction sums over.
        out = cp.zeros((1, 3, n_groups, n_features), dtype=cp.float64)
        n_rows, n_cols = X_part.shape
        if n_rows == 0 or n_cols == 0:
            # An empty block contributes nothing; skip the launch because a
            # zero-sized grid is not a valid kernel configuration.
            return out

        if is_sparse:
            # The sparse kernel is handed CSR buffers (indptr/indices/data),
            # so bring any other compressed layout into CSR first.
            if getattr(X_part, "format", "csr") != "csr":
                X_part = X_part.tocsr()
            grid = (n_rows,)
            kernel_args = (X_part.indptr, X_part.indices, X_part.data)
        else:
            # The kernel receives the raw buffer, not the strides, and the
            # C-order kernel is used unconditionally here. Transposed or
            # sliced chunks must therefore be made C-contiguous first
            # (this is a no-op for chunks that already are).
            X_part = cp.ascontiguousarray(X_part)
            n_elements = n_rows * n_cols
            blocks = min(
                (n_elements + threads_per_block - 1) // threads_per_block,
                cp.cuda.Device().attributes["MultiProcessorCount"] * 8,
            )
            grid = (blocks,)
            kernel_args = (X_part,)

        kernel(
            grid,
            (threads_per_block,),
            (
                *kernel_args,
                out,
                groupby_part,
                mask_part,
                n_rows,
                n_cols,
                n_groups,
            ),
        )
        return out

    # Chunk the per-row mask and group codes exactly like the rows of the
    # data so that every block receives its matching slice.
    row_chunks = self.data.chunks[0]
    mask = self._get_mask()
    mask_dask = da.from_array(
        mask, chunks=(row_chunks,), meta=_meta_dense(mask.dtype)
    )
    groupby_dask = da.from_array(
        self.groupby,
        chunks=(row_chunks,),
        meta=_meta_dense(self.groupby.dtype),
    )

    # Apply the aggregation to every block; the trailing axis added to the
    # 1-D helpers lets them align with the 2-D data blocks.
    partial = da.map_blocks(
        _aggregate_block,
        self.data,
        mask_dask[..., None],
        groupby_dask[..., None],
        meta=cp.empty([], dtype=cp.float64),
        dtype=cp.float64,
        new_axis=(1, 2),
        chunks=(
            (1,) * self.data.blocks.size,
            (3,),
            (n_groups,),
            (n_features,),
        ),
    )

    # Combine the per-block partial results and bring them into memory.
    out = partial.sum(axis=0, split_every=split_every).compute()
    sums, counts, sq_sums = out[0], out[1], out[2]

    # Derive the statistics from the accumulated sums.
    counts = counts.astype(cp.int32)
    means = sums / self.n_cells
    var = sq_sums / self.n_cells - cp.power(means, 2)
    var *= self.n_cells / (self.n_cells - dof)

    return {"mean": means, "var": var, "sum": sums, "count_nonzero": counts}

def count_mean_var_sparse(self, dof: int = 1):
"""
This function is used to calculate the sum, mean, and variance of the sparse data matrix.
It uses a custom cuda-kernel to perform the aggregation.
"""

assert dof >= 0
from ._kernels._aggr_kernels import _get_aggr_sparse_kernel
from ._kernels._aggr_kernels import (
_get_aggr_sparse_kernel,
_get_aggr_sparse_kernel_csc,
)

out = cp.zeros(
(3, self.n_cells.shape[0] * self.data.shape[1]), dtype=cp.float64
)

block = (512,)
if self.data.format == "csc":
self.data = self.data.tocsr()
means = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
var = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
sums = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
counts = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.int32)
block = (128,)
grid = (self.data.shape[0],)
aggr_kernel = _get_aggr_sparse_kernel(self.data.dtype)
grid = (self.data.shape[1],)
aggr_kernel = _get_aggr_sparse_kernel_csc(self.data.dtype)
else:
grid = (self.data.shape[0],)
aggr_kernel = _get_aggr_sparse_kernel(self.data.dtype)
mask = self._get_mask()
aggr_kernel(
grid,
Expand All @@ -92,23 +203,24 @@ def count_mean_var_sparse(self, dof: int = 1):
self.data.indptr,
self.data.indices,
self.data.data,
counts,
sums,
means,
var,
out,
self.groupby,
self.n_cells,
mask,
self.data.shape[0],
self.data.shape[1],
self.n_cells.shape[0],
),
)

var = var - cp.power(means, 2)
sums, counts, sq_sums = out[0, :], out[1, :], out[2, :]
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])
sq_sums = sq_sums.reshape(self.n_cells.shape[0], self.data.shape[1])
counts = counts.reshape(self.n_cells.shape[0], self.data.shape[1])
counts = counts.astype(cp.int32)
means = sums / self.n_cells
var = sq_sums / self.n_cells - means**2
var *= self.n_cells / (self.n_cells - dof)

results = {"sum": sums, "count_nonzero": counts, "mean": means, "var": var}

return results

def count_mean_var_sparse_sparse(self, funcs, dof: int = 1):
Expand Down Expand Up @@ -275,34 +387,44 @@ def count_mean_var_dense(self, dof: int = 1):
"""

assert dof >= 0
from ._kernels._aggr_kernels import _get_aggr_dense_kernel
from ._kernels._aggr_kernels import (
_get_aggr_dense_kernel_C,
_get_aggr_dense_kernel_F,
)

means = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
var = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
sums = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)
counts = cp.zeros((self.n_cells.shape[0], self.data.shape[1]), dtype=cp.int32)
block = (128,)
grid = (self.data.shape[0],)
aggr_kernel = _get_aggr_dense_kernel(self.data.dtype)
out = cp.zeros((3, self.n_cells.shape[0], self.data.shape[1]), dtype=cp.float64)

N = self.data.shape[0] * self.data.shape[1]
threads_per_block = 512
blocks = min(
(N + threads_per_block - 1) // threads_per_block,
cp.cuda.Device().attributes["MultiProcessorCount"] * 8,
)
if self.data.flags.c_contiguous:
aggr_kernel = _get_aggr_dense_kernel_C(self.data.dtype)
else:
aggr_kernel = _get_aggr_dense_kernel_F(self.data.dtype)
mask = self._get_mask()
aggr_kernel(
grid,
block,
(blocks,),
(threads_per_block,),
(
self.data.data,
counts,
sums,
means,
var,
self.data,
out,
self.groupby,
self.n_cells,
mask,
self.data.shape[0],
self.data.shape[1],
self.n_cells.shape[0],
),
)

var = var - cp.power(means, 2)
sums, counts, sq_sums = out[0], out[1], out[2]
sums = sums.reshape(self.n_cells.shape[0], self.data.shape[1])
counts = counts.reshape(self.n_cells.shape[0], self.data.shape[1])
sq_sums = sq_sums.reshape(self.n_cells.shape[0], self.data.shape[1])
counts = counts.astype(cp.int32)
means = sums / self.n_cells
var = sq_sums / self.n_cells - cp.power(means, 2)
var *= self.n_cells / (self.n_cells - dof)

results = {"sum": sums, "count_nonzero": counts, "mean": means, "var": var}
Expand All @@ -322,6 +444,7 @@ def aggregate(
obsm: str | None = None,
varm: str | None = None,
return_sparse: bool = False,
**kwargs,
) -> AnnData:
"""\
Aggregate data matrix based on some categorical grouping.
Expand Down Expand Up @@ -416,11 +539,9 @@ def aggregate(
elif axis == 1:
# i.e., all of `varm`, `obsm`, `layers` are None so we use `X` which must be transposed
data = data.T
_check_gpu_X(data)
_check_gpu_X(data, allow_dask=True)
dim_df = getattr(adata, axis_name)
categorical, new_label_df = _combine_categories(dim_df, by)
# Actual computation

groupby = Aggregate(groupby=categorical, data=data, mask=mask)

funcs = set([func] if isinstance(func, str) else func)
Expand All @@ -429,6 +550,15 @@ def aggregate(

if isinstance(data, cp.ndarray):
result = groupby.count_mean_var_dense(dof)
elif isinstance(data, DaskArray):
if "split_every" in kwargs:
assert isinstance(kwargs["split_every"], int)
assert kwargs["split_every"] > 0
split_every = kwargs["split_every"]
else:
split_every = 2
result = groupby.count_mean_var_dask(dof, split_every=split_every)

else:
if return_sparse:
result = groupby.count_mean_var_sparse_sparse(funcs, dof)
Expand Down
Loading
Loading