Skip to content

Commit c673c02

Browse files
authored
Add dask to logreg (#413)
* 1st Iteration * update and test * add release note * update test
1 parent d6d55ac commit c673c02

4 files changed

Lines changed: 82 additions & 4 deletions

File tree

docs/release-notes/0.13.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
```
55
* Add support for aggregate operations on CSC matrices, Fortran-ordered arrays, and Dask with sparse CSR and dense matrices {pr}`395` {smaller}`S Dicks`
66
* Adds dask support for `tl.score_genes` & `tl.score_genes_cell_cycle` {pr}`408` {smaller}`S Dicks`
7+
* Adds dask support for `tl.rank_genes_groups_logreg` {pr}`413` {smaller}`S Dicks`
78

89
```{rubric} Performance
910
```

src/rapids_singlecell/tools/_rank_gene_groups.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pandas as pd
88

9+
from rapids_singlecell._compat import DaskArray, _meta_dense
10+
911
if TYPE_CHECKING:
1012
from collections.abc import Iterable
1113

@@ -52,7 +54,7 @@ def rank_genes_groups_logreg(
5254
groupby: str,
5355
*,
5456
groups: Literal["all"] | Iterable[str] = "all",
55-
use_raw: bool = None,
57+
use_raw: bool | None = None,
5658
reference: str = "rest",
5759
n_genes: int = None,
5860
layer: str = None,
@@ -155,7 +157,6 @@ def rank_genes_groups_logreg(
155157
# if reference is not set, then the groups listed will be compared to the rest
156158
# if reference is set, then the groups listed will be compared only to the other groups listed
157159
refname = reference
158-
from cuml.linear_model import LogisticRegression
159160

160161
reference = groups_order[0]
161162
if len(groups) == 1:
@@ -167,15 +168,28 @@ def rank_genes_groups_logreg(
167168
X = X[grouping_mask.values, :]
168169
# Indexing with a series causes issues, possibly segfault
169170

170-
grouping_logreg = grouping.cat.codes.to_numpy().astype("float32")
171+
grouping_logreg = grouping.cat.codes.to_numpy().astype(X.dtype)
171172
uniques = np.unique(grouping_logreg)
172173
for idx, cat in enumerate(uniques):
173174
grouping_logreg[np.where(grouping_logreg == cat)] = idx
174175

176+
if isinstance(X, DaskArray):
177+
import dask.array as da
178+
from cuml.dask.linear_model import LogisticRegression
179+
180+
grouping_logreg = da.from_array(
181+
grouping_logreg,
182+
chunks=(X.chunks[0]),
183+
meta=_meta_dense(grouping_logreg.dtype),
184+
)
185+
else:
186+
from cuml.linear_model import LogisticRegression
187+
188+
clf = LogisticRegression(**kwds)
189+
175190
clf = LogisticRegression(**kwds)
176191
clf.fit(X, grouping_logreg)
177192
scores_all = cp.array(clf.coef_)
178-
179193
if len(groups_order) == scores_all.shape[1]:
180194
scores_all = scores_all.T
181195
for igroup, _group in enumerate(groups_order):
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
import cupy as cp
4+
import pandas as pd
5+
import pytest
6+
from scanpy.datasets import pbmc3k_processed, pbmc68k_reduced
7+
8+
import rapids_singlecell as rsc
9+
from rapids_singlecell._testing import (
10+
as_dense_cupy_dask_array,
11+
as_sparse_cupy_dask_array,
12+
)
13+
14+
15+
@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
16+
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
17+
def test_rank_genes_groups_logreg(client, data_kind, dtype):
18+
if data_kind == "dense":
19+
adata = pbmc68k_reduced()
20+
adata.X = adata.X.astype(dtype)
21+
dask_data = adata.copy()
22+
dask_data.X = as_dense_cupy_dask_array(dask_data.X).persist()
23+
rsc.get.anndata_to_GPU(adata)
24+
groupby = "bulk_labels"
25+
read = "Dendritic"
26+
elif data_kind == "sparse":
27+
adata = pbmc3k_processed()
28+
org_var_names = adata.var_names
29+
adata = adata.raw.to_adata()
30+
adata = adata[:, org_var_names].copy()
31+
adata.X = adata.X.astype(dtype)
32+
dask_data = adata.copy()
33+
dask_data.X = as_sparse_cupy_dask_array(dask_data.X).persist()
34+
rsc.get.anndata_to_GPU(adata)
35+
groupby = "louvain"
36+
read = "B cells"
37+
38+
rsc.tl.rank_genes_groups_logreg(adata, groupby=groupby, use_raw=False)
39+
rsc.tl.rank_genes_groups_logreg(dask_data, groupby=groupby, use_raw=False)
40+
array_ad = pd.DataFrame(adata.uns["rank_genes_groups"]["scores"][read]).to_numpy()[
41+
:10
42+
]
43+
array_bd = pd.DataFrame(
44+
dask_data.uns["rank_genes_groups"]["scores"][read]
45+
).to_numpy()[:10]
46+
cp.testing.assert_allclose(array_ad, array_bd, atol=1e-3)

tests/test_rank_genes_groups_logreg.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import numpy as np
4+
import pandas as pd
45
import scanpy as sc
56

67
import rapids_singlecell as rsc
@@ -38,3 +39,19 @@ def test_rank_genes_groups_with_renamed_categories_use_rep():
3839

3940
rsc.tl.rank_genes_groups_logreg(adata, "blobs")
4041
assert not adata.uns["rank_genes_groups"]["names"][0].tolist() == ("3", "1", "0")
42+
43+
44+
def test_rank_genes_groups_with_unsorted_groups():
45+
adata = sc.datasets.blobs(n_variables=10, n_centers=5, n_observations=200)
46+
adata._sanitize()
47+
adata.rename_categories("blobs", ["Zero", "One", "Two", "Three", "Four"])
48+
bdata = adata.copy()
49+
rsc.tl.rank_genes_groups_logreg(adata, "blobs", groups=["Zero", "One", "Three"])
50+
rsc.tl.rank_genes_groups_logreg(bdata, "blobs", groups=["One", "Three", "Zero"])
51+
array_ad = pd.DataFrame(
52+
adata.uns["rank_genes_groups"]["scores"]["Three"]
53+
).to_numpy()
54+
array_bd = pd.DataFrame(
55+
bdata.uns["rank_genes_groups"]["scores"]["Three"]
56+
).to_numpy()
57+
np.testing.assert_equal(array_ad, array_bd)

0 commit comments

Comments
 (0)