Skip to content

Commit d6d55ac

Browse files
authored
add score_genes support for Dask (#408)
* add score_genes * clean up X_to _GPU
1 parent 06854bf commit d6d55ac

7 files changed

Lines changed: 250 additions & 24 deletions

File tree

docs/release-notes/0.13.0.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
```{rubric} Features
44
```
55
* Add support for aggregate operations on CSC matrices, Fortran-ordered arrays, and Dask with sparse CSR and dense matrices {pr}`395` {smaller}`S Dicks`
6+
* Adds dask support for `tl.score_genes` & `tl.score_genes_cell_cycle` {pr}`408` {smaller}`S Dicks`
67

78
```{rubric} Performance
89
```
910

1011
```{rubric} Bug fixes
1112
```
12-
13+
* Fixes a bug for `_get_mean_var` with dask chunk sizes {pr}`408` {smaller}`S Dicks`
1314

1415
```{rubric} Misc
1516
```

src/rapids_singlecell/get/_anndata.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,9 @@ def X_to_GPU(
9191
if isinstance(X, GPU_ARRAY_TYPE):
9292
pass
9393
elif isinstance(X, DaskArray):
94-
if isinstance(X._meta, csc_matrix_cpu):
95-
pass
96-
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
97-
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
94+
if isinstance(X._meta, csr_matrix_cpu | np.ndarray):
95+
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
96+
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
9897
elif isspmatrix_csr_cpu(X):
9998
X = csr_matrix_gpu(X)
10099
elif isspmatrix_csc_cpu(X):

src/rapids_singlecell/preprocessing/_utils.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import cupy as cp
77
import numpy as np
88
import pandas as pd
9-
from cuml.internals.memory_utils import with_cupy_rmm
109
from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr, spmatrix
1110
from natsort import natsorted
1211
from pandas.api.types import infer_dtype
@@ -98,7 +97,6 @@ def _mean_var_minor(X, major, minor):
9897
return mean, var
9998

10099

101-
@with_cupy_rmm
102100
def _mean_var_minor_dask(X, major, minor):
103101
"""
104102
Implements sum operation for dask array when the backend is cupy sparse csr matrix
@@ -134,7 +132,6 @@ def __mean_var(X_part):
134132

135133

136134
# todo: Implement this dynamically for csc matrix as well
137-
@with_cupy_rmm
138135
def _mean_var_major_dask(X, major, minor):
139136
"""
140137
Implements sum operation for dask array when the backend is cupy sparse csr matrix
@@ -165,23 +162,23 @@ def __mean_var(X_part):
165162
minor,
166163
),
167164
)
168-
return cp.vstack([mean, var])
165+
return cp.stack([mean, var], axis=1)
169166

170-
mean, var = X.map_blocks(
167+
output = X.map_blocks(
171168
__mean_var,
172-
chunks=((2,), X.chunks[0]),
169+
chunks=(X.chunks[0], (2,)),
173170
dtype=cp.float64,
174171
meta=cp.array([]),
175172
)
176-
173+
mean = output[:, 0]
174+
var = output[:, 1]
177175
mean = mean / minor
178176
var = var / minor
179-
var -= cp.power(mean, 2)
177+
var -= mean**2
180178
var *= minor / (minor - 1)
181179
return mean, var
182180

183181

184-
@with_cupy_rmm
185182
def _mean_var_dense_dask(X, axis):
186183
"""
187184
Implements sum operation for dask array when the backend is cupy dense matrix
@@ -192,25 +189,24 @@ def __mean_var(X_part):
192189
var = sq_sum(X_part, axis=axis)
193190
mean = mean_sum(X_part, axis=axis)
194191
if axis == 0:
195-
mean = mean.reshape(-1, 1)
196-
var = var.reshape(-1, 1)
197-
return cp.vstack([mean.ravel(), var.ravel()])[
198-
None if 1 - axis else slice(None, None), ...
199-
]
192+
return cp.vstack([mean, var])[None, ...]
193+
else:
194+
return cp.stack([mean, var], axis=1)
200195

201196
n_blocks = X.blocks.size
202197
mean_var = X.map_blocks(
203198
__mean_var,
204199
new_axis=(1,) if axis - 1 else None,
205-
chunks=((2,), X.chunks[0]) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)),
200+
chunks=(X.chunks[0], (2,)) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)),
206201
dtype=cp.float64,
207202
meta=cp.array([]),
208203
)
209204

210205
if axis == 0:
211206
mean, var = mean_var.sum(axis=0)
212207
else:
213-
mean, var = mean_var
208+
mean = mean_var[:, 0]
209+
var = mean_var[:, 1]
214210

215211
mean = mean / X.shape[axis]
216212
var = var / X.shape[axis]

src/rapids_singlecell/tools/_score_genes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from typing import TYPE_CHECKING
55

66
import cupy as cp
7+
import dask
78
import numpy as np
89
import pandas as pd
910

11+
from rapids_singlecell._compat import DaskArray
1012
from rapids_singlecell.get import X_to_GPU, _get_obs_rep
1113
from rapids_singlecell.preprocessing._utils import _check_gpu_X, _check_use_raw
1214

@@ -77,8 +79,7 @@ def score_genes(
7779
use_raw = _check_use_raw(adata, use_raw, layer=layer)
7880
X = _get_obs_rep(adata, layer=layer, use_raw=use_raw)
7981
X = X_to_GPU(X)
80-
_check_gpu_X(X)
81-
82+
_check_gpu_X(X, allow_dask=True)
8283
if random_state is not None:
8384
np.random.seed(random_state)
8485

@@ -108,6 +109,8 @@ def score_genes(
108109
means_control = _nan_mean(
109110
X, axis=1, mask=control_array, n_features=len(control_genes)
110111
)
112+
if isinstance(X, DaskArray):
113+
means_list, means_control = dask.compute(means_list, means_control)
111114

112115
score = means_list - means_control
113116

@@ -157,7 +160,10 @@ def _score_genes_bins(
157160
) -> Generator[pd.Index[str], None, None]:
158161
# average expression of genes
159162
idx = cp.array(var_names.isin(gene_pool), dtype=cp.bool_)
160-
nanmeans = _nan_mean(X, axis=0, mask=idx, n_features=len(gene_pool)).get()
163+
nanmeans = _nan_mean(X, axis=0, mask=idx, n_features=len(gene_pool))
164+
if isinstance(X, DaskArray):
165+
nanmeans = nanmeans.compute()
166+
nanmeans = nanmeans.get()
161167
obs_avg = pd.Series(nanmeans, index=gene_pool)
162168
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
163169
obs_avg = obs_avg[np.isfinite(obs_avg)]

src/rapids_singlecell/tools/_utils.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import cupy as cp
66
from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr
77

8+
from rapids_singlecell._compat import DaskArray
9+
810
from . import pca
911

1012

@@ -47,6 +49,100 @@ def _choose_representation(adata, use_rep=None, n_pcs=None):
4749
return X
4850

4951

52+
def _nan_mean_minor_dask_sparse(X, major, minor, *, mask=None, n_features=None):
53+
from ._kernels._nan_mean_kernels import _get_nan_mean_minor
54+
55+
kernel = _get_nan_mean_minor(X.dtype)
56+
kernel.compile()
57+
58+
def __nan_mean_minor(X_part):
59+
mean = cp.zeros(minor, dtype=cp.float64)
60+
nans = cp.zeros(minor, dtype=cp.int32)
61+
tpb = (32,)
62+
bpg_x = math.ceil(X_part.nnz / 32)
63+
bpg = (bpg_x,)
64+
kernel(bpg, tpb, (X_part.indices, X_part.data, mean, nans, mask, X_part.nnz))
65+
return cp.vstack([mean, nans.astype(cp.float64)])[None, ...]
66+
67+
n_blocks = X.blocks.size
68+
mean, nans = X.map_blocks(
69+
__nan_mean_minor,
70+
new_axis=(1,),
71+
chunks=((1,) * n_blocks, (2,), (minor,)),
72+
dtype=cp.float64,
73+
meta=cp.array([]),
74+
).sum(axis=0)
75+
mean /= n_features - nans
76+
return mean
77+
78+
79+
def _nan_mean_major_dask_sparse(X, major, minor, *, mask=None, n_features=None):
80+
from ._kernels._nan_mean_kernels import _get_nan_mean_major
81+
82+
kernel = _get_nan_mean_major(X.dtype)
83+
kernel.compile()
84+
85+
def __nan_mean_major(X_part):
86+
major_part = X_part.shape[0]
87+
mean = cp.zeros(major_part, dtype=cp.float64)
88+
nans = cp.zeros(major_part, dtype=cp.int32)
89+
block = (64,)
90+
grid = (major_part,)
91+
kernel(
92+
grid,
93+
block,
94+
(
95+
X_part.indptr,
96+
X_part.indices,
97+
X_part.data,
98+
mean,
99+
nans,
100+
mask,
101+
major_part,
102+
minor,
103+
),
104+
)
105+
return cp.stack([mean, nans.astype(cp.float64)], axis=1)
106+
107+
output = X.map_blocks(
108+
__nan_mean_major,
109+
chunks=(X.chunks[0], (2,)),
110+
dtype=cp.float64,
111+
meta=cp.array([]),
112+
)
113+
mean = output[:, 0]
114+
nans = output[:, 1]
115+
mean /= n_features - nans
116+
return mean
117+
118+
119+
def _nan_mean_dense_dask(X, axis, *, mask, n_features):
120+
def __nan_mean_dense(X_part):
121+
X_to_use = X_part[:, mask].astype(cp.float64)
122+
sum = cp.nansum(X_to_use, axis=axis).ravel()
123+
nans = cp.sum(cp.isnan(X_to_use), axis=axis).ravel()
124+
if axis == 1:
125+
return cp.stack([sum, nans.astype(cp.float64)], axis=1)
126+
else:
127+
return cp.vstack([sum, nans.astype(cp.float64)])[None, ...]
128+
129+
n_blocks = X.blocks.size
130+
output = X.map_blocks(
131+
__nan_mean_dense,
132+
new_axis=(1,) if axis - 1 else None,
133+
chunks=(X.chunks[0], (2,)) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)),
134+
dtype=cp.float64,
135+
meta=cp.array([]),
136+
)
137+
if axis == 0:
138+
mean, nans = output.sum(axis=0)
139+
else:
140+
mean = output[:, 0]
141+
nans = output[:, 1]
142+
mean /= n_features - nans
143+
return mean
144+
145+
50146
def _nan_mean_minor(X, major, minor, *, mask=None, n_features=None):
51147
from ._kernels._nan_mean_kernels import _get_nan_mean_minor
52148

@@ -120,6 +216,34 @@ def _nan_mean(X, axis=0, *, mask=None, n_features=None):
120216
mean = _nan_mean_minor(
121217
X, major, minor, mask=mask, n_features=n_features
122218
)
219+
elif isinstance(X, DaskArray):
220+
if isspmatrix_csr(X._meta):
221+
major, minor = X.shape
222+
if mask is None:
223+
mask = cp.ones(X.shape[1], dtype=cp.bool_)
224+
if axis == 0:
225+
n_features = major
226+
mean = _nan_mean_minor_dask_sparse(
227+
X, major, minor, mask=mask, n_features=n_features
228+
)
229+
elif axis == 1:
230+
n_features = minor if n_features is None else n_features
231+
mean = _nan_mean_major_dask_sparse(
232+
X, major, minor, mask=mask, n_features=n_features
233+
)
234+
else:
235+
raise ValueError("axis must be either 0 or 1")
236+
elif isinstance(X._meta, cp.ndarray):
237+
if mask is None:
238+
mask = cp.ones(X.shape[1], dtype=cp.bool_)
239+
if n_features is None:
240+
n_features = X.shape[axis]
241+
mean = _nan_mean_dense_dask(X, axis, mask=mask, n_features=n_features)
242+
# raise NotImplementedError("Dask dense arrays are not supported yet")
243+
else:
244+
raise ValueError(
245+
"Type not supported. Please provide a CuPy ndarray or a CuPy sparse matrix. Or a Dask array with a CuPy ndarray or a CuPy sparse matrix as meta."
246+
)
123247
else:
124248
if mask is None:
125249
mask = cp.ones(X.shape[1], dtype=cp.bool_)

tests/dask/test_dask_mean_var.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
import cupy as cp
4+
import pytest
5+
from scanpy.datasets import pbmc3k, pbmc68k_reduced
6+
7+
import rapids_singlecell as rsc
8+
from rapids_singlecell._testing import (
9+
as_dense_cupy_dask_array,
10+
as_sparse_cupy_dask_array,
11+
)
12+
from rapids_singlecell.preprocessing._utils import _get_mean_var
13+
14+
from ..test_score_genes import _create_sparse_nan_matrix # noqa: TID252
15+
16+
17+
@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
18+
@pytest.mark.parametrize("axis", [0, 1])
19+
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
20+
def test_mean_var(client, data_kind, axis, dtype):
21+
if data_kind == "dense":
22+
adata = pbmc68k_reduced()
23+
adata.X = adata.X.astype(dtype)
24+
dask_data = adata.copy()
25+
dask_data.X = as_dense_cupy_dask_array(dask_data.X).persist()
26+
rsc.get.anndata_to_GPU(adata)
27+
elif data_kind == "sparse":
28+
adata = pbmc3k()
29+
adata.X = adata.X.astype(dtype)
30+
dask_data = adata.copy()
31+
dask_data.X = as_sparse_cupy_dask_array(dask_data.X).persist()
32+
rsc.get.anndata_to_GPU(adata)
33+
34+
mean, var = _get_mean_var(adata.X, axis=axis)
35+
dask_mean, dask_var = _get_mean_var(dask_data.X, axis=axis)
36+
dask_mean, dask_var = dask_mean.compute(), dask_var.compute()
37+
38+
cp.testing.assert_allclose(mean, dask_mean)
39+
cp.testing.assert_allclose(var, dask_var)
40+
41+
42+
@pytest.mark.parametrize("array_type", ["csr", "dense"])
43+
@pytest.mark.parametrize("percent_nan", [0, 0.3])
44+
def test_sparse_nanmean(client, array_type, percent_nan):
45+
"""Needs to be fixed"""
46+
from rapids_singlecell.tools._utils import _nan_mean
47+
48+
R, C = 100, 50
49+
50+
# sparse matrix with nan
51+
S = _create_sparse_nan_matrix(R, C, percent_zero=0.3, percent_nan=percent_nan)
52+
S = S.astype(cp.float64)
53+
A = S.toarray()
54+
A = rsc.get.X_to_GPU(A)
55+
56+
if array_type == "dense":
57+
S = as_dense_cupy_dask_array(A).persist()
58+
else:
59+
S = as_sparse_cupy_dask_array(S).persist()
60+
61+
cp.testing.assert_allclose(
62+
_nan_mean(A, 1).ravel(), (_nan_mean(S, 1)).ravel().compute()
63+
)
64+
cp.testing.assert_allclose(
65+
_nan_mean(A, 0).ravel(), (_nan_mean(S, 0)).ravel().compute()
66+
)

0 commit comments

Comments
 (0)