|
5 | 5 | import cupy as cp |
6 | 6 | from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr |
7 | 7 |
|
| 8 | +from rapids_singlecell._compat import DaskArray |
| 9 | + |
8 | 10 | from . import pca |
9 | 11 |
|
10 | 12 |
|
@@ -47,6 +49,100 @@ def _choose_representation(adata, use_rep=None, n_pcs=None): |
47 | 49 | return X |
48 | 50 |
|
49 | 51 |
|
def _nan_mean_minor_dask_sparse(X, major, minor, *, mask=None, n_features=None):
    """
    Lazily compute the NaN-aware mean along axis 0 of a sparse-backed Dask array.

    Each block is reduced independently into a ``(1, 2, minor)`` array whose
    row 0 is the ``mean`` buffer and row 1 the ``nans`` buffer filled by the
    kernel returned from ``_get_nan_mean_minor`` (presumably per-column sums
    of the non-NaN values and per-column NaN counts -- confirm against the
    kernel source). The per-block results are summed over the block axis and
    the result is divided by ``n_features - nans``.

    Parameters
    ----------
    X
        Dask array whose blocks expose ``indices``, ``data`` and ``nnz``.
    major
        Number of rows of ``X``; only used as the default for ``n_features``.
    minor
        Number of columns of ``X``; length of the per-block output buffers.
    mask
        Array forwarded unchanged to the kernel. Defaults to an all-``True``
        boolean array of length ``minor``.
    n_features
        Value the NaN counts are subtracted from to form the denominator.
        Defaults to ``major``.

    Returns
    -------
    Lazy Dask array of length ``minor``.
    """
    from ._kernels._nan_mean_kernels import _get_nan_mean_minor

    # The signature advertises None defaults; resolve them so that neither
    # the kernel call nor the final division ever sees None.
    if mask is None:
        mask = cp.ones(minor, dtype=cp.bool_)
    if n_features is None:
        n_features = major

    kernel = _get_nan_mean_minor(X.dtype)
    kernel.compile()

    threads_per_block = 32

    def __nan_mean_minor(X_part):
        mean = cp.zeros(minor, dtype=cp.float64)
        nans = cp.zeros(minor, dtype=cp.int32)
        nnz = X_part.nnz
        # A zero-sized grid is not a valid CUDA launch configuration, and a
        # block without stored entries has nothing to accumulate anyway.
        if nnz > 0:
            # Integer ceil-division: one thread per stored entry.
            blocks_per_grid = (nnz + threads_per_block - 1) // threads_per_block
            kernel(
                (blocks_per_grid,),
                (threads_per_block,),
                (X_part.indices, X_part.data, mean, nans, mask, nnz),
            )
        # Leading axis of length 1 so the blocks can be stacked and summed.
        return cp.vstack([mean, nans.astype(cp.float64)])[None, ...]

    n_blocks = X.blocks.size
    mean, nans = X.map_blocks(
        __nan_mean_minor,
        new_axis=(1,),
        chunks=((1,) * n_blocks, (2,), (minor,)),
        dtype=cp.float64,
        meta=cp.array([]),
    ).sum(axis=0)
    return mean / (n_features - nans)
| 77 | + |
| 78 | + |
def _nan_mean_major_dask_sparse(X, major, minor, *, mask=None, n_features=None):
    """
    Lazily compute the NaN-aware mean along axis 1 of a sparse-backed Dask array.

    Each block is reduced independently into a ``(rows_in_block, 2)`` array
    whose column 0 is the ``mean`` buffer and column 1 the ``nans`` buffer
    filled by the kernel returned from ``_get_nan_mean_major`` (presumably
    per-row sums of the non-NaN values and per-row NaN counts -- confirm
    against the kernel source). Column 0 is then divided by
    ``n_features - nans``.

    Parameters
    ----------
    X
        Dask array whose blocks expose ``indptr``, ``indices``, ``data`` and
        ``shape``.
    major
        Number of rows of ``X``. Not used here; each block uses its own row
        count.
    minor
        Number of columns of ``X``; forwarded to the kernel and used as the
        default for ``n_features``.
    mask
        Array forwarded unchanged to the kernel. Defaults to an all-``True``
        boolean array of length ``minor``.
    n_features
        Value the NaN counts are subtracted from to form the denominator.
        Defaults to ``minor``.

    Returns
    -------
    Lazy Dask array with one entry per row of ``X``.
    """
    from ._kernels._nan_mean_kernels import _get_nan_mean_major

    # The signature advertises None defaults; resolve them so that neither
    # the kernel call nor the final division ever sees None.
    if mask is None:
        mask = cp.ones(minor, dtype=cp.bool_)
    if n_features is None:
        n_features = minor

    kernel = _get_nan_mean_major(X.dtype)
    kernel.compile()

    def __nan_mean_major(X_part):
        major_part = X_part.shape[0]
        mean = cp.zeros(major_part, dtype=cp.float64)
        nans = cp.zeros(major_part, dtype=cp.int32)
        # The grid has one CUDA block per row of this Dask block; a zero-sized
        # grid is not a valid launch configuration, so skip empty blocks.
        if major_part > 0:
            kernel(
                (major_part,),
                (64,),
                (
                    X_part.indptr,
                    X_part.indices,
                    X_part.data,
                    mean,
                    nans,
                    mask,
                    major_part,
                    minor,
                ),
            )
        return cp.stack([mean, nans.astype(cp.float64)], axis=1)

    output = X.map_blocks(
        __nan_mean_major,
        chunks=(X.chunks[0], (2,)),
        dtype=cp.float64,
        meta=cp.array([]),
    )
    mean = output[:, 0]
    nans = output[:, 1]
    return mean / (n_features - nans)
| 117 | + |
| 118 | + |
| 119 | +def _nan_mean_dense_dask(X, axis, *, mask, n_features): |
| 120 | + def __nan_mean_dense(X_part): |
| 121 | + X_to_use = X_part[:, mask].astype(cp.float64) |
| 122 | + sum = cp.nansum(X_to_use, axis=axis).ravel() |
| 123 | + nans = cp.sum(cp.isnan(X_to_use), axis=axis).ravel() |
| 124 | + if axis == 1: |
| 125 | + return cp.stack([sum, nans.astype(cp.float64)], axis=1) |
| 126 | + else: |
| 127 | + return cp.vstack([sum, nans.astype(cp.float64)])[None, ...] |
| 128 | + |
| 129 | + n_blocks = X.blocks.size |
| 130 | + output = X.map_blocks( |
| 131 | + __nan_mean_dense, |
| 132 | + new_axis=(1,) if axis - 1 else None, |
| 133 | + chunks=(X.chunks[0], (2,)) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)), |
| 134 | + dtype=cp.float64, |
| 135 | + meta=cp.array([]), |
| 136 | + ) |
| 137 | + if axis == 0: |
| 138 | + mean, nans = output.sum(axis=0) |
| 139 | + else: |
| 140 | + mean = output[:, 0] |
| 141 | + nans = output[:, 1] |
| 142 | + mean /= n_features - nans |
| 143 | + return mean |
| 144 | + |
| 145 | + |
50 | 146 | def _nan_mean_minor(X, major, minor, *, mask=None, n_features=None): |
51 | 147 | from ._kernels._nan_mean_kernels import _get_nan_mean_minor |
52 | 148 |
|
@@ -120,6 +216,34 @@ def _nan_mean(X, axis=0, *, mask=None, n_features=None): |
120 | 216 | mean = _nan_mean_minor( |
121 | 217 | X, major, minor, mask=mask, n_features=n_features |
122 | 218 | ) |
| 219 | + elif isinstance(X, DaskArray): |
| 220 | + if isspmatrix_csr(X._meta): |
| 221 | + major, minor = X.shape |
| 222 | + if mask is None: |
| 223 | + mask = cp.ones(X.shape[1], dtype=cp.bool_) |
| 224 | + if axis == 0: |
| 225 | + n_features = major |
| 226 | + mean = _nan_mean_minor_dask_sparse( |
| 227 | + X, major, minor, mask=mask, n_features=n_features |
| 228 | + ) |
| 229 | + elif axis == 1: |
| 230 | + n_features = minor if n_features is None else n_features |
| 231 | + mean = _nan_mean_major_dask_sparse( |
| 232 | + X, major, minor, mask=mask, n_features=n_features |
| 233 | + ) |
| 234 | + else: |
| 235 | + raise ValueError("axis must be either 0 or 1") |
| 236 | + elif isinstance(X._meta, cp.ndarray): |
| 237 | + if mask is None: |
| 238 | + mask = cp.ones(X.shape[1], dtype=cp.bool_) |
| 239 | + if n_features is None: |
| 240 | + n_features = X.shape[axis] |
| 241 | + mean = _nan_mean_dense_dask(X, axis, mask=mask, n_features=n_features) |
| 242 | + # raise NotImplementedError("Dask dense arrays are not supported yet") |
| 243 | + else: |
| 244 | + raise ValueError( |
| 245 | + "Type not supported. Please provide a CuPy ndarray or a CuPy sparse matrix. Or a Dask array with a CuPy ndarray or a CuPy sparse matrix as meta." |
| 246 | + ) |
123 | 247 | else: |
124 | 248 | if mask is None: |
125 | 249 | mask = cp.ones(X.shape[1], dtype=cp.bool_) |
|
0 commit comments