Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test-gpu-rpr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
environment-file: ci/environment_alpha.yml
init-shell: >-
bash
cache-environment: true
post-cleanup: 'all'

- name: Install rapids-singlecell
Expand Down
49 changes: 39 additions & 10 deletions tests/dask/test_dask_rank_logreg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import cupy as cp
import pandas as pd
import pytest
from scanpy.datasets import pbmc3k_processed, pbmc68k_reduced

Expand All @@ -12,6 +11,41 @@
)


def _compare_top_genes(result1, result2, top_n=10, min_overlap=9):
"""
Compare top N genes between two rank_genes_groups results.

Parameters
----------
result1, result2 : dict
Results from rank_genes_groups with 'names' key
top_n : int
Number of top genes to compare
min_overlap : int
Minimum number of overlapping genes required

Returns
-------
bool
True if overlap meets minimum threshold for all groups
"""
groups1 = result1["names"].dtype.names
groups2 = result2["names"].dtype.names

if set(groups1) != set(groups2):
return False

for group in groups1:
top_genes1 = set(result1["names"][group][:top_n])
top_genes2 = set(result2["names"][group][:top_n])
overlap = len(top_genes1.intersection(top_genes2))

if overlap < min_overlap:
return False

return True


@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_rank_genes_groups_logreg(client, data_kind, dtype):
Expand All @@ -22,7 +56,6 @@ def test_rank_genes_groups_logreg(client, data_kind, dtype):
dask_data.X = as_dense_cupy_dask_array(dask_data.X).persist()
rsc.get.anndata_to_GPU(adata)
groupby = "bulk_labels"
read = "Dendritic"
elif data_kind == "sparse":
adata = pbmc3k_processed()
org_var_names = adata.var_names
Expand All @@ -33,14 +66,10 @@ def test_rank_genes_groups_logreg(client, data_kind, dtype):
dask_data.X = as_sparse_cupy_dask_array(dask_data.X).persist()
rsc.get.anndata_to_GPU(adata)
groupby = "louvain"
read = "B cells"

rsc.tl.rank_genes_groups_logreg(adata, groupby=groupby, use_raw=False)
rsc.tl.rank_genes_groups_logreg(dask_data, groupby=groupby, use_raw=False)
array_ad = pd.DataFrame(adata.uns["rank_genes_groups"]["scores"][read]).to_numpy()[
:10
]
array_bd = pd.DataFrame(
dask_data.uns["rank_genes_groups"]["scores"][read]
).to_numpy()[:10]
cp.testing.assert_allclose(array_ad, array_bd, atol=1e-3)

assert _compare_top_genes(
adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]
)
Loading