Skip to content

Commit 5fd8aa6

Browse files
authored
Fix rank_genes_groups with groups parameter (#651)
* fix ttest groups
* add release note
* add more tests
1 parent 4aeefb6 commit 5fd8aa6

6 files changed

Lines changed: 161 additions & 12 deletions

File tree

docs/release-notes/0.15.1.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
### 0.15.1 {small}`the-future`
2+
3+
```{rubric} Bug fixes
4+
```
5+
* Fixes `tl.rank_genes_groups` returning NaN or zero `logfoldchanges`/`pvals` when `groups` selects only a subset of the categories and `reference='rest'` {pr}`651` {smaller}`S Dicks`

docs/release-notes/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55

66
## Version 0.15.0
7+
```{include} /release-notes/0.15.1.md
8+
```
79
```{include} /release-notes/0.15.0.md
810
```
911

src/rapids_singlecell/tools/_rank_genes_groups/_core.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,33 +174,45 @@ def _basic_stats(self) -> None:
174174
cat_to_idx = {str(name): i for i, name in enumerate(cat_names)}
175175
order = [cat_to_idx[str(name)] for name in self.groups_order]
176176

177+
# Aggregate returns stats per ALL categories. Slice to selected groups
178+
# for per-group means/vars; keep the all-category arrays for "rest"
179+
# stats so the totals stay correct when ``groups`` is a strict subset.
180+
sums_all = result["sum"]
181+
sq_sums_all = result["sq_sum"]
182+
nnz_all = result["count_nonzero"] if self.comp_pts else None
183+
177184
n = cp.asarray(self.group_sizes, dtype=cp.float64)[:, None]
178-
sums = result["sum"][order]
179-
sq_sums = result["sq_sum"][order]
185+
sums = sums_all[order]
186+
sq_sums = sq_sums_all[order]
180187

181188
# Compute means and variances from raw sums (all on GPU)
182189
means = sums / n
183190
group_ss = sq_sums - n * means**2
184191
vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)
185192

186193
if self.comp_pts:
187-
pts = result["count_nonzero"][order].astype(cp.float64) / n
194+
pts = nnz_all[order].astype(cp.float64) / n
188195
else:
189196
pts = None
190197

191-
# Compute rest statistics if reference='rest'
198+
# Compute rest statistics if reference='rest' — "rest" means every
199+
# cell in ``groupby`` not in this group, including cells in
200+
# categories that weren't selected via ``groups=``.
192201
if self.ireference is None:
193-
n_rest = n.sum() - n
194-
means_rest = (sums.sum(axis=0) - sums) / n_rest
195-
rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2
202+
n_total = agg.n_cells.sum()
203+
n_rest = n_total - n
204+
means_rest = (sums_all.sum(axis=0) - sums) / n_rest
205+
rest_ss = (sq_sums_all.sum(axis=0) - sq_sums) - n_rest * means_rest**2
196206
vars_rest = cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)
197207

198208
self.means_rest = cp.asnumpy(means_rest)
199209
self.vars_rest = cp.asnumpy(vars_rest)
200210

201211
if self.comp_pts:
202-
total_count = (pts * n).sum(axis=0)
203-
self.pts_rest = cp.asnumpy((total_count - pts * n) / n_rest)
212+
nnz_total = nnz_all.sum(axis=0)
213+
self.pts_rest = cp.asnumpy(
214+
(nnz_total - nnz_all[order]).astype(cp.float64) / n_rest
215+
)
204216
else:
205217
self.pts_rest = None
206218
else:

tests/test_rank_genes_groups_ttest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,63 @@ def test_rank_genes_groups_ttest_subset_and_bonferroni(reference, method):
149149
assert np.all(adjusted <= 1.0)
150150

151151

152+
@pytest.mark.parametrize(
153+
("groups", "reference"),
154+
[
155+
(["0"], "rest"),
156+
(["0", "2"], "rest"),
157+
(["0"], "1"),
158+
(["0", "2"], "1"),
159+
],
160+
)
161+
@pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"])
162+
def test_rank_genes_groups_ttest_subset_matches_scanpy(method, groups, reference):
163+
np.random.seed(42)
164+
adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=5, n_observations=200)
165+
adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category")
166+
adata_cpu = adata_gpu.copy()
167+
168+
rsc.tl.rank_genes_groups(
169+
adata_gpu,
170+
"blobs",
171+
method=method,
172+
groups=groups,
173+
reference=reference,
174+
use_raw=False,
175+
)
176+
sc.tl.rank_genes_groups(
177+
adata_cpu,
178+
"blobs",
179+
method=method,
180+
groups=groups,
181+
reference=reference,
182+
use_raw=False,
183+
)
184+
185+
gpu_result = adata_gpu.uns["rank_genes_groups"]
186+
cpu_result = adata_cpu.uns["rank_genes_groups"]
187+
188+
assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names
189+
for group in gpu_result["names"].dtype.names:
190+
gpu_names = list(gpu_result["names"][group])
191+
cpu_names = list(cpu_result["names"][group])
192+
for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"):
193+
gpu_map = dict(
194+
zip(gpu_names, np.asarray(gpu_result[field][group], dtype=float))
195+
)
196+
cpu_map = dict(
197+
zip(cpu_names, np.asarray(cpu_result[field][group], dtype=float))
198+
)
199+
for gene, gpu_val in gpu_map.items():
200+
np.testing.assert_allclose(
201+
gpu_val,
202+
cpu_map[gene],
203+
rtol=1e-6,
204+
atol=1e-8,
205+
err_msg=f"{field} mismatch for gene {gene} group {group}",
206+
)
207+
208+
152209
@pytest.mark.parametrize(
153210
"reference_before,reference_after",
154211
[("rest", "rest"), ("1", "One")],

tests/test_rank_genes_groups_wilcoxon.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,69 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference):
148148
assert np.all(adjusted <= 1.0)
149149

150150

151+
@pytest.mark.parametrize(
152+
("groups", "reference"),
153+
[
154+
(["0"], "rest"),
155+
(["0", "2"], "rest"),
156+
(["0"], "1"),
157+
(["0", "2"], "1"),
158+
],
159+
)
160+
@pytest.mark.parametrize("tie_correct", [False, True])
161+
@pytest.mark.parametrize("pre_load", [False, True])
162+
def test_rank_genes_groups_wilcoxon_subset_matches_scanpy(
163+
groups, reference, tie_correct, pre_load
164+
):
165+
np.random.seed(42)
166+
adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=5, n_observations=200)
167+
adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category")
168+
adata_cpu = adata_gpu.copy()
169+
170+
rsc.tl.rank_genes_groups(
171+
adata_gpu,
172+
"blobs",
173+
method="wilcoxon",
174+
groups=groups,
175+
reference=reference,
176+
use_raw=False,
177+
tie_correct=tie_correct,
178+
pre_load=pre_load,
179+
)
180+
sc.tl.rank_genes_groups(
181+
adata_cpu,
182+
"blobs",
183+
method="wilcoxon",
184+
groups=groups,
185+
reference=reference,
186+
use_raw=False,
187+
tie_correct=tie_correct,
188+
)
189+
190+
gpu_result = adata_gpu.uns["rank_genes_groups"]
191+
cpu_result = adata_cpu.uns["rank_genes_groups"]
192+
193+
assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names
194+
for group in gpu_result["names"].dtype.names:
195+
gpu_names = list(gpu_result["names"][group])
196+
cpu_names = list(cpu_result["names"][group])
197+
for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"):
198+
gpu_map = dict(
199+
zip(gpu_names, np.asarray(gpu_result[field][group], dtype=float))
200+
)
201+
cpu_map = dict(
202+
zip(cpu_names, np.asarray(cpu_result[field][group], dtype=float))
203+
)
204+
for gene, gpu_val in gpu_map.items():
205+
np.testing.assert_allclose(
206+
gpu_val,
207+
cpu_map[gene],
208+
rtol=1e-6,
209+
atol=1e-8,
210+
err_msg=f"{field} mismatch for gene {gene} group {group}",
211+
)
212+
213+
151214
@pytest.mark.parametrize(
152215
"reference_before,reference_after",
153216
[("rest", "rest"), ("1", "One")],

tests/test_rank_genes_groups_wilcoxon_binned.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ def test_reference_matches_exact(self, adata_blobs, reference, groups):
264264
@pytest.mark.parametrize(
265265
("reference", "groups"),
266266
[
267+
pytest.param("rest", ["0"], id="rest_single_group"),
267268
pytest.param("rest", ["0", "2"], id="rest_group_subset"),
269+
pytest.param("1", ["0"], id="ref_single_group"),
268270
pytest.param("1", ["0", "1", "2"], id="ref_group_subset"),
269271
],
270272
)
@@ -300,9 +302,17 @@ def test_group_subset_matches_all_groups(self, adata_blobs, reference, groups):
300302
)
301303

302304
for group in result_sub["names"].dtype.names:
303-
scores_all = np.asarray(result_all["scores"][group], dtype=float)
304-
scores_sub = np.asarray(result_sub["scores"][group], dtype=float)
305-
np.testing.assert_allclose(scores_all, scores_sub, rtol=1e-10)
305+
assert tuple(result_all["names"][group]) == tuple(
306+
result_sub["names"][group]
307+
), f"names mismatch for group {group}"
308+
for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"):
309+
np.testing.assert_allclose(
310+
np.asarray(result_all[field][group], dtype=float),
311+
np.asarray(result_sub[field][group], dtype=float),
312+
rtol=1e-10,
313+
atol=1e-12,
314+
err_msg=f"{field} mismatch for group {group}",
315+
)
306316

307317
@pytest.mark.parametrize("reference", ["rest", "1"])
308318
def test_unsorted_groups(self, adata_blobs, reference):

0 commit comments

Comments
 (0)