88)
99
1010import cupy as cp
11- import numpy as np
1211from anndata import AnnData
1312from cupyx .scipy import sparse as cp_sparse
1413from scanpy ._utils import _resolve_axis
1514from scanpy .get ._aggregated import _combine_categories
1615
16+ from rapids_singlecell ._compat import (
17+ DaskArray ,
18+ _meta_dense ,
19+ )
1720from rapids_singlecell .get import _check_mask
1821from rapids_singlecell .preprocessing ._utils import _check_gpu_X
1922
2023if TYPE_CHECKING :
2124 from collections .abc import Collection , Iterable
2225
26+ import numpy as np
2327 import pandas as pd
2428 from numpy .typing import NDArray
2529
@@ -52,7 +56,7 @@ def __init__(
5256 ) -> None :
5357 self .mask = mask
5458 self .groupby = cp .array (groupby .codes , dtype = cp .int32 )
55- self .n_cells = cp .array (np .bincount (groupby . codes ), dtype = cp .float64 ).reshape (
59+ self .n_cells = cp .array (cp .bincount (self . groupby ), dtype = cp .float64 ).reshape (
5660 - 1 , 1
5761 )
5862 self .data = data
@@ -66,24 +70,131 @@ def _get_mask(self):
6670 else :
6771 return cp .ones (self .data .shape [0 ], dtype = bool )
6872
def count_mean_var_dask(self, dof: int = 1, split_every: int = 2):
    """
    Compute per-group sum, mean, variance and non-zero count for a Dask array.

    Each row-block of ``self.data`` is aggregated on the GPU by a custom CUDA
    kernel (the dense kernel when the Dask meta is a ``cupy.ndarray``, the
    sparse kernel otherwise). The per-block partial results are then summed
    with Dask and turned into the final statistics.

    Parameters
    ----------
    dof
        Delta degrees of freedom used to rescale the variance; must be >= 0.
    split_every
        Forwarded to the Dask ``sum`` reduction that combines the per-block
        partial results.

    Returns
    -------
    dict
        Keys ``"mean"``, ``"var"``, ``"sum"`` and ``"count_nonzero"``, each an
        array with one row per group and one column per column of
        ``self.data``.

    Raises
    ------
    AssertionError
        If ``dof`` is negative.
    ValueError
        If ``self.data`` is split into more than one chunk along axis 1.
    """
    # Validate first so bad input fails before any heavy import or GPU work.
    assert dof >= 0

    import dask.array as da

    from ._kernels._aggr_kernels import (
        _get_aggr_dense_kernel_C,
        _get_aggr_sparse_kernel,
    )

    # Every block writes into an accumulator that spans all columns, and the
    # output chunk layout declared below has exactly one block per row-chunk.
    # Column-chunked input would violate both assumptions, so reject it early
    # with an actionable message.
    if len(self.data.chunks[1]) != 1:
        raise ValueError(
            "count_mean_var_dask requires a single chunk along axis 1; "
            "rechunk the data so that every block spans all columns."
        )

    if isinstance(self.data._meta, cp.ndarray):
        kernel = _get_aggr_dense_kernel_C(self.data.dtype)
        is_sparse = False
    else:
        # NOTE(review): the sparse path reads indptr/indices/data from each
        # block, so blocks are presumably CSR - confirm against the kernel.
        kernel = _get_aggr_sparse_kernel(self.data.dtype)
        is_sparse = True

    # Compile once up front instead of lazily inside the first Dask task.
    kernel.compile()
    n_groups = self.n_cells.shape[0]
    n_cols = self.data.shape[1]
    row_chunks = self.data.chunks[0]

    def _aggregate_block(X_part, mask_part, groupby_part):
        # Partial result for one block; the leading axis of length 1 is the
        # axis that is summed over after map_blocks.
        out = cp.zeros((1, 3, n_groups, n_cols), dtype=cp.float64)
        threads_per_block = 512

        if is_sparse:
            # Sparse matrix kernel parameters: one CUDA block per row.
            grid = (X_part.shape[0],)
            kernel_args = (
                X_part.indptr,
                X_part.indices,
                X_part.data,
            )
        else:
            # Dense matrix kernel parameters: enough blocks to cover every
            # element, capped at 8 blocks per multiprocessor.
            N = X_part.shape[0] * X_part.shape[1]

            blocks = min(
                (N + threads_per_block - 1) // threads_per_block,
                cp.cuda.Device().attributes["MultiProcessorCount"] * 8,
            )
            grid = (blocks,)
            kernel_args = (X_part,)

        kernel(
            grid,
            (threads_per_block,),
            (
                *kernel_args,
                out,
                groupby_part,
                mask_part,
                X_part.shape[0],
                X_part.shape[1],
                n_groups,
            ),
        )
        return out

    # Prepare Dask arrays chunked exactly like the rows of the data. The
    # chunks are passed as an explicit 1-tuple of per-block sizes so the
    # layout never depends on Dask's handling of a flat tuple of ints.
    mask = self._get_mask()
    mask_dask = da.from_array(
        mask, chunks=(row_chunks,), meta=_meta_dense(mask.dtype)
    )
    groupby_dask = da.from_array(
        self.groupby,
        chunks=(row_chunks,),
        meta=_meta_dense(self.groupby.dtype),
    )

    # Apply aggregation across all blocks
    out = da.map_blocks(
        _aggregate_block,
        self.data,
        mask_dask[..., None],
        groupby_dask[..., None],
        meta=cp.empty([], dtype=cp.float64),
        dtype=cp.float64,
        new_axis=(1, 2),
        chunks=(
            (1,) * self.data.blocks.size,
            (3,),
            (n_groups,),
            (n_cols,),
        ),
    )

    # Compute final aggregated results
    out = out.sum(axis=0, split_every=split_every).compute()
    sums, counts, sq_sums = out[0], out[1], out[2]

    # Calculate statistics
    counts = counts.astype(cp.int32)
    means = sums / self.n_cells
    var = sq_sums / self.n_cells - cp.power(means, 2)
    var *= self.n_cells / (self.n_cells - dof)

    return {"mean": means, "var": var, "sum": sums, "count_nonzero": counts}
174+
69175 def count_mean_var_sparse (self , dof : int = 1 ):
70176 """
71177 This function is used to calculate the sum, mean, and variance of the sparse data matrix.
72178 It uses a custom cuda-kernel to perform the aggregation.
73179 """
74180
75181 assert dof >= 0
76- from ._kernels ._aggr_kernels import _get_aggr_sparse_kernel
182+ from ._kernels ._aggr_kernels import (
183+ _get_aggr_sparse_kernel ,
184+ _get_aggr_sparse_kernel_csc ,
185+ )
186+
187+ out = cp .zeros (
188+ (3 , self .n_cells .shape [0 ] * self .data .shape [1 ]), dtype = cp .float64
189+ )
77190
191+ block = (512 ,)
78192 if self .data .format == "csc" :
79- self .data = self .data .tocsr ()
80- means = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
81- var = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
82- sums = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
83- counts = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .int32 )
84- block = (128 ,)
85- grid = (self .data .shape [0 ],)
86- aggr_kernel = _get_aggr_sparse_kernel (self .data .dtype )
193+ grid = (self .data .shape [1 ],)
194+ aggr_kernel = _get_aggr_sparse_kernel_csc (self .data .dtype )
195+ else :
196+ grid = (self .data .shape [0 ],)
197+ aggr_kernel = _get_aggr_sparse_kernel (self .data .dtype )
87198 mask = self ._get_mask ()
88199 aggr_kernel (
89200 grid ,
@@ -92,23 +203,24 @@ def count_mean_var_sparse(self, dof: int = 1):
92203 self .data .indptr ,
93204 self .data .indices ,
94205 self .data .data ,
95- counts ,
96- sums ,
97- means ,
98- var ,
206+ out ,
99207 self .groupby ,
100- self .n_cells ,
101208 mask ,
102209 self .data .shape [0 ],
103210 self .data .shape [1 ],
211+ self .n_cells .shape [0 ],
104212 ),
105213 )
106-
107- var = var - cp .power (means , 2 )
214+ sums , counts , sq_sums = out [0 , :], out [1 , :], out [2 , :]
215+ sums = sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
216+ sq_sums = sq_sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
217+ counts = counts .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
218+ counts = counts .astype (cp .int32 )
219+ means = sums / self .n_cells
220+ var = sq_sums / self .n_cells - means ** 2
108221 var *= self .n_cells / (self .n_cells - dof )
109222
110223 results = {"sum" : sums , "count_nonzero" : counts , "mean" : means , "var" : var }
111-
112224 return results
113225
114226 def count_mean_var_sparse_sparse (self , funcs , dof : int = 1 ):
@@ -275,34 +387,44 @@ def count_mean_var_dense(self, dof: int = 1):
275387 """
276388
277389 assert dof >= 0
278- from ._kernels ._aggr_kernels import _get_aggr_dense_kernel
390+ from ._kernels ._aggr_kernels import (
391+ _get_aggr_dense_kernel_C ,
392+ _get_aggr_dense_kernel_F ,
393+ )
279394
280- means = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
281- var = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
282- sums = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
283- counts = cp .zeros ((self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .int32 )
284- block = (128 ,)
285- grid = (self .data .shape [0 ],)
286- aggr_kernel = _get_aggr_dense_kernel (self .data .dtype )
395+ out = cp .zeros ((3 , self .n_cells .shape [0 ], self .data .shape [1 ]), dtype = cp .float64 )
396+
397+ N = self .data .shape [0 ] * self .data .shape [1 ]
398+ threads_per_block = 512
399+ blocks = min (
400+ (N + threads_per_block - 1 ) // threads_per_block ,
401+ cp .cuda .Device ().attributes ["MultiProcessorCount" ] * 8 ,
402+ )
403+ if self .data .flags .c_contiguous :
404+ aggr_kernel = _get_aggr_dense_kernel_C (self .data .dtype )
405+ else :
406+ aggr_kernel = _get_aggr_dense_kernel_F (self .data .dtype )
287407 mask = self ._get_mask ()
288408 aggr_kernel (
289- grid ,
290- block ,
409+ ( blocks ,) ,
410+ ( threads_per_block ,) ,
291411 (
292- self .data .data ,
293- counts ,
294- sums ,
295- means ,
296- var ,
412+ self .data ,
413+ out ,
297414 self .groupby ,
298- self .n_cells ,
299415 mask ,
300416 self .data .shape [0 ],
301417 self .data .shape [1 ],
418+ self .n_cells .shape [0 ],
302419 ),
303420 )
304-
305- var = var - cp .power (means , 2 )
421+ sums , counts , sq_sums = out [0 ], out [1 ], out [2 ]
422+ sums = sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
423+ counts = counts .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
424+ sq_sums = sq_sums .reshape (self .n_cells .shape [0 ], self .data .shape [1 ])
425+ counts = counts .astype (cp .int32 )
426+ means = sums / self .n_cells
427+ var = sq_sums / self .n_cells - cp .power (means , 2 )
306428 var *= self .n_cells / (self .n_cells - dof )
307429
308430 results = {"sum" : sums , "count_nonzero" : counts , "mean" : means , "var" : var }
@@ -322,6 +444,7 @@ def aggregate(
322444 obsm : str | None = None ,
323445 varm : str | None = None ,
324446 return_sparse : bool = False ,
447+ ** kwargs ,
325448) -> AnnData :
326449 """\
327450 Aggregate data matrix based on some categorical grouping.
@@ -416,11 +539,9 @@ def aggregate(
416539 elif axis == 1 :
417540 # i.e., all of `varm`, `obsm`, `layers` are None so we use `X` which must be transposed
418541 data = data .T
419- _check_gpu_X (data )
542+ _check_gpu_X (data , allow_dask = True )
420543 dim_df = getattr (adata , axis_name )
421544 categorical , new_label_df = _combine_categories (dim_df , by )
422- # Actual computation
423-
424545 groupby = Aggregate (groupby = categorical , data = data , mask = mask )
425546
426547 funcs = set ([func ] if isinstance (func , str ) else func )
@@ -429,6 +550,15 @@ def aggregate(
429550
430551 if isinstance (data , cp .ndarray ):
431552 result = groupby .count_mean_var_dense (dof )
553+ elif isinstance (data , DaskArray ):
554+ if "split_every" in kwargs :
555+ assert isinstance (kwargs ["split_every" ], int )
556+ assert kwargs ["split_every" ] > 0
557+ split_every = kwargs ["split_every" ]
558+ else :
559+ split_every = 2
560+ result = groupby .count_mean_var_dask (dof , split_every = split_every )
561+
432562 else :
433563 if return_sparse :
434564 result = groupby .count_mean_var_sparse_sparse (funcs , dof )
0 commit comments