55
66import numpy as np
77import pandas as pd
8- from anndata import AnnData , utils
8+ from anndata import AnnData
99from fast_array_utils .stats ._power import power as fau_power # TODO: upstream
1010from scipy import sparse
1111from sklearn .utils .sparsefuncs import csc_median_axis_0
1212
1313from scanpy ._compat import CSBase , CSRBase , DaskArray
1414
1515from .._utils import _resolve_axis , get_literal_vals
16+ from ._kernels import agg_sum_csc , agg_sum_csr , mean_var_csc , mean_var_csr
1617from .get import _check_mask
1718
1819if TYPE_CHECKING :
2526type AggType = ConstantDtypeAgg | Literal ["mean" , "var" ]
2627
2728
28- class Aggregate :
29+ class Aggregate [ ArrayT : np . ndarray | CSBase ] :
2930 """Functionality for generic grouping and aggregating.
3031
3132 There is currently support for count_nonzero, sum, mean, and variance.
@@ -53,19 +54,22 @@ class Aggregate:
def __init__(
    self,
    groupby: pd.Categorical,
    data: ArrayT,
    *,
    mask: NDArray[np.bool] | None = None,
) -> None:
    """Store the grouping and build its group-by-observation indicator matrix.

    Parameters
    ----------
    groupby
        Group label of every observation.
    data
        Values to aggregate, either a dense array or a compressed sparse matrix.
        Presumably one row per entry of ``groupby`` -- TODO confirm against callers.
    mask
        Optional boolean selection of observations. Observations whose
        ``groupby`` label is missing are always excluded, with or without a mask.

    """
    self.groupby = groupby
    # A missing label has no row in the indicator matrix, so it is folded into the mask.
    missing = groupby.isna()
    if missing.any():
        mask = ~missing if mask is None else mask & ~missing
    indicator = sparse_indicator(groupby, mask=mask)
    if isinstance(data, CSBase):
        # TODO: Look into if this can be CSR and fast for dense
        indicator = indicator.tocsr()
    self.indicator_matrix = indicator
    self.data = data


# Group label of every observation, as passed to the constructor.
groupby: pd.Categorical
# CSR when `data` is a compressed sparse matrix, COO otherwise.
indicator_matrix: CSRBase | sparse.coo_array
# The values being aggregated.
data: ArrayT
6973
def count_nonzero(self) -> NDArray[np.integer]:
    """Count the non-zero entries per feature per group of observations.

    Returns
    -------
    Array of counts.

    """
    # Summing a 0/1 pattern of the data per group counts its non-zero entries.
    # NOTE(review): the pattern is uint8 to keep the temporary small; this relies
    # on `_sum` accumulating in a wider dtype -- confirm for the dense path.
    pattern = (self.data != 0).astype("uint8")
    return self._sum(data=pattern)
83+
def _sum(self, data: ArrayT) -> np.ndarray:
    """Sum ``data`` per feature within each group.

    Parameters
    ----------
    data
        Dense array or compressed sparse matrix to aggregate.

    Returns
    -------
    Dense array with one row per row of the indicator matrix.

    """
    if isinstance(data, np.ndarray):
        indicator = self.indicator_matrix
        # scipy keeps the narrowest common dtype for the product, so e.g. a uint8
        # indicator times uint8 data would wrap at 256. Widen the (small)
        # indicator so integer sums accumulate in int64, as in the sparse path.
        if indicator.dtype.kind in "biu" and data.dtype.kind in "biu":
            indicator = indicator.astype(np.int64)
        res = indicator @ data
        if isinstance(res, CSBase):
            return res.toarray()
        return res
    dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64
    out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype)
    # The kernel is called with the preallocated `out`, which is what gets returned.
    kernel = agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc
    kernel(self.indicator_matrix, data, out)
    return out
8196
82- def sum (self ) -> Array :
97+ def sum (self ) -> np . ndarray :
8398 """Compute the sum per feature per group of observations.
8499
85100 Returns
86101 -------
87102 Array of sum.
88103
89104 """
90- return utils . asarray ( self .indicator_matrix @ self .data )
105+ return self ._sum ( self .data )
91106
def mean(self) -> np.ndarray:
    """Compute the mean per feature per group of observations.

    Returns
    -------
    Array of mean.

    """
    codes = np.asarray(self.groupby.codes)
    # Missing labels are coded -1, which `np.bincount` rejects; drop them.
    # `minlength` keeps one count per category even when trailing categories
    # are unused, so the counts line up with the rows of the sum.
    # NOTE(review): counts come from `groupby` alone, so observations excluded
    # via `mask` still contribute to the denominator -- confirm this is intended.
    group_counts = np.bincount(
        codes[codes >= 0], minlength=len(self.groupby.categories)
    )
    return self.sum() / group_counts[:, None]
104116
105117 def mean_var (self , dof : int = 1 ) -> tuple [np .ndarray , np .ndarray ]:
106118 """Compute the count, as well as mean and variance per feature, per group of observations.
@@ -124,14 +136,17 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
124136 assert dof >= 0
125137
126138 group_counts = np .bincount (self .groupby .codes )
127- mean_ = self .mean ()
128- # sparse matrices do not support ** for elementwise power.
129- mean_sq = (
130- utils .asarray (self .indicator_matrix @ _power (self .data , 2 ))
131- / group_counts [:, None ]
132- )
133- sq_mean = mean_ ** 2
134- var_ = mean_sq - sq_mean
139+ if isinstance (self .data , np .ndarray ):
140+ mean_ = self .mean ()
141+ # sparse matrices do not support ** for elementwise power.
142+ mean_sq = self ._sum (_power (self .data , 2 )) / group_counts [:, None ]
143+ sq_mean = mean_ ** 2
144+ var_ = mean_sq - sq_mean
145+ else :
146+ mean_ , var_ = (
147+ mean_var_csr if isinstance (self .data , CSRBase ) else mean_var_csc
148+ )(self .indicator_matrix , self .data )
149+ sq_mean = mean_ ** 2
135150 # TODO: Why these values exactly? Because they are high relative to the datatype?
136151 # (unchanged from original code: https://github.com/scverse/anndata/pull/564)
137152 precision = 2 << (42 if self .data .dtype == np .float64 else 20 )
def sparse_indicator(
    categorical: pd.Categorical,
    *,
    mask: NDArray[np.bool] | None = None,
) -> sparse.coo_array:
    """Build the ``(n_categories, n_observations)`` group-membership matrix.

    Column ``j`` holds a single stored entry, in the row of observation ``j``'s
    category. Its value is 1 when the observation is selected and 0 when it is
    masked out or has a missing label.

    Parameters
    ----------
    categorical
        Category of every observation.
    mask
        Optional boolean selection of observations.

    Returns
    -------
    Indicator matrix, ``float64`` without a mask and ``uint8`` with one.

    """
    n_obs = len(categorical)
    codes = np.asarray(categorical.codes)
    # A missing label is coded -1, which is not a valid row index; treat it as masked out.
    present = codes >= 0
    if mask is None:
        keep = present
        # TODO: why is this float64? This is a scanpy 2.0 problem maybe?
        weight = (
            np.broadcast_to(1.0, n_obs) if present.all() else present.astype(np.float64)
        )
    else:
        keep = np.asarray(mask, dtype=bool) & present
        weight = keep.astype("uint8")
    # can't have -1s in the codes, but (as long as it's valid), the value is ignored, so set to 0 where masked
    rows = np.where(keep, codes, 0)
    return sparse.coo_array(
        (weight, (rows, np.arange(n_obs))),
        shape=(len(categorical.categories), n_obs),
    )
0 commit comments