Skip to content

Commit 85ea3d3

Browse files
authored
unblock CI (#695)
* unblock CI * update rapids-pre * test a10 * loosen dask * fix dask hangs * use zeros instead of nan * pin numpy * more pinning * update numba pin
1 parent 0eaeaaf commit 85ea3d3

9 files changed

Lines changed: 62 additions & 68 deletions

File tree

.cirun.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
runners:
2-
# Primary: T4. Cheapest, usually has capacity.
3-
- name: aws-gpu-runner-g4dn
2+
# Primary: A10G. ~3-4x faster than the T4 (Ampere, tensor cores, ~600 GB/s).
3+
- name: aws-gpu-runner-g5
44
cloud: aws
5-
instance_type: g4dn.xlarge
5+
instance_type: g5.xlarge
66
machine_image: ami-067a4ba2816407ee9
77
region: eu-north-1
88
preemptible:
99
- true
1010
- false
1111
labels:
1212
- cirun-aws-gpu
13-
# Fallback: A10G. cirun picks this when g4dn spot is dry.
14-
- name: aws-gpu-runner-g5
13+
# Fallback: T4. Cheapest; cirun picks this when g5 spot is dry.
14+
- name: aws-gpu-runner-g4dn
1515
cloud: aws
16-
instance_type: g5.xlarge
16+
instance_type: g4dn.xlarge
1717
machine_image: ami-067a4ba2816407ee9
1818
region: eu-north-1
1919
preemptible:

docs/release-notes/0.15.2.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
```{rubric} Features
44
```
5-
* Add `ptg.GuideAssignment.assign_mixture_model` for crispat-style Poisson-Gaussian guide assignment. The CUDA/nanobind implementation writes pertpy-compatible labels to `adata.obs` and stores both pertpy-style and crispat-style model readouts in `adata.var` {pr}`637` {smaller}`S Dicks`
5+
* Add GPU-accelerated {class}`~rapids_singlecell.ptg.GuideAssignment` (``assign_by_threshold``, ``assign_to_max_guide``, ``assign_mixture_model``), mirroring `pertpy.pp.GuideAssignment` {pr}`637` {smaller}`S Dicks`
66
* Add pseudobulk based distance metrics to {class}`~rapids_singlecell.ptg.Distance`: ``euclidean``, ``root_mean_squared_error``, ``mse``, ``mean_absolute_error``, ``pearson_distance``, ``cosine_distance``, ``r2_distance``. Matches ``pertpy.tl.Distance`` {pr}`676` {smaller}`S Dicks`
77
* Add bootstrap support (``bootstrap=True``) to the pseudobulk distance metrics of {class}`~rapids_singlecell.ptg.Distance` for ``pairwise`` and ``onesided_distances``, plus array-level ``Distance.bootstrap``. Each iteration resamples cells per group on the GPU and recomputes the group-mean distances {pr}`684` {smaller}`S Dicks`
88
* Add ``wasserstein`` metric to {class}`~rapids_singlecell.ptg.Distance` {pr}`683` {smaller}`S Dicks`

hatch.toml

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,35 @@ overrides.matrix.deps.extra-dependencies = [
3838
{ if = [
3939
"dev",
4040
], value = "scanpy @ git+https://github.com/scverse/scanpy.git" },
41+
# numpy 2.5 removed `np.row_stack`, which numba-cuda still calls -> the whole
42+
# suite fails at collection. Pin <2.5 for the prerelease-allowing envs until
43+
# numba-cuda drops the call; remove this once that lands.
44+
{ if = [
45+
"dev",
46+
"rapids_prerelease",
47+
], value = "numpy<2.5" },
48+
# UV_PRERELEASE=allow otherwise drags in mutually-incompatible numba/numba-cuda
49+
# prereleases (numba 0.63.0b1 -> removed np.trapz; numba 0.66.0rc2 + numba-cuda
50+
# 0.30.2 -> missing numba.cuda.types.NPDatetime). Pin the known-good stable pair
51+
# (numba 0.64.x + numba-cuda 0.28.x) so only RAPIDS uses nightlies.
52+
{ if = [
53+
"dev",
54+
"rapids_prerelease",
55+
], value = "numba>=0.64,<0.65" },
56+
{ if = [
57+
"dev",
58+
"rapids_prerelease",
59+
], value = "numba-cuda<0.30" },
4160
]
4261
overrides.matrix.cuda.extra-dependencies = [
43-
{ if = [ "13" ], value = "cuml-cu13<26.8" },
44-
{ if = [ "13" ], value = "cudf-cu13<26.8" },
45-
{ if = [ "13" ], value = "cugraph-cu13<26.8" },
46-
{ if = [ "13" ], value = "cuvs-cu13<26.8" },
47-
{ if = [ "12" ], value = "cuml-cu12<26.8" },
48-
{ if = [ "12" ], value = "cudf-cu12<26.8" },
49-
{ if = [ "12" ], value = "cugraph-cu12<26.8" },
50-
{ if = [ "12" ], value = "cuvs-cu12<26.8" },
62+
{ if = [ "13" ], value = "cuml-cu13<26.10" },
63+
{ if = [ "13" ], value = "cudf-cu13<26.10" },
64+
{ if = [ "13" ], value = "cugraph-cu13<26.10" },
65+
{ if = [ "13" ], value = "cuvs-cu13<26.10" },
66+
{ if = [ "12" ], value = "cuml-cu12<26.10" },
67+
{ if = [ "12" ], value = "cudf-cu12<26.10" },
68+
{ if = [ "12" ], value = "cugraph-cu12<26.10" },
69+
{ if = [ "12" ], value = "cuvs-cu12<26.10" },
5170
]
5271

5372
## For prerelease we rely on UV_PRERELEASE + nightly index; features select cu12/cu13

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ addopts = [
123123
]
124124
markers = [
125125
"gpu: tests that use a GPU (currently unused, but needs to be specified here as we import anndata.tests.helpers, which uses it)",
126+
"array_api: array-API tests (currently unused, but needs to be specified here as we import anndata.tests.helpers, which uses it)",
126127
]
127128

128129
[tool.setuptools_scm]

src/rapids_singlecell/pertpy_gpu/_guide_assignment.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@
2525
"mix_probs_0",
2626
"mix_probs_1",
2727
"threshold",
28-
"weight_Poisson",
29-
"weight_Normal",
30-
"lambda",
31-
"mu",
32-
"scale",
3328
]
3429

3530

@@ -38,9 +33,8 @@ class GuideAssignment:
3833
3934
Provides threshold-based and mixture-model-based methods for assigning
4035
cells to guide RNAs, compatible with pertpy's ``GuideAssignment`` API.
41-
The mixture model follows crispat's Poisson-Gaussian assignment rule
42-
while using batched EM on GPU instead of per-guide Pyro SVI, yielding
43-
orders-of-magnitude speedup.
36+
The mixture model fits a Poisson-Gaussian mixture per guide with batched
37+
EM on GPU, yielding orders-of-magnitude speedup.
4438
"""
4539

4640
def assign_by_threshold(
@@ -150,10 +144,9 @@ def assign_mixture_model(
150144
151145
Fits a two-component mixture (Poisson background + Gaussian signal)
152146
to the log₂-transformed non-zero counts of each guide simultaneously
153-
using batched Expectation-Maximization on GPU. Like crispat's
154-
Poisson-Gaussian assignment, the fitted model is converted to an
155-
integer raw-count threshold. The default posterior cutoff matches
156-
pertpy's crispat-style threshold rule.
147+
using batched Expectation-Maximization on GPU. The fitted model is
148+
converted to an integer raw-count threshold; the default posterior
149+
cutoff matches pertpy's threshold rule.
157150
158151
Parameters
159152
----------
@@ -255,11 +248,6 @@ def assign_mixture_model(
255248
adata.var.loc[valid_var_index, "mix_probs_0"] = pi0_cpu
256249
adata.var.loc[valid_var_index, "mix_probs_1"] = 1.0 - pi0_cpu
257250
adata.var.loc[valid_var_index, "threshold"] = thresholds_cpu
258-
adata.var.loc[valid_var_index, "weight_Poisson"] = pi0_cpu
259-
adata.var.loc[valid_var_index, "weight_Normal"] = 1.0 - pi0_cpu
260-
adata.var.loc[valid_var_index, "lambda"] = lam_cpu
261-
adata.var.loc[valid_var_index, "mu"] = mu_cpu
262-
adata.var.loc[valid_var_index, "scale"] = sigma_cpu
263251

264252
adata.obs[assigned_guides_key] = series_values
265253
return None

src/rapids_singlecell/preprocessing/_neighbors/_neighbors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def _get_connectivities_umap(
144144
"""UMAP fuzzy simplicial set connectivities."""
145145
set_op_mix_ratio = 1.0
146146
local_connectivity = 1.0
147-
X_conn = cp.empty((n_obs, 1), dtype=np.float32)
147+
148+
X_conn = cp.zeros((n_obs, 1), dtype=np.float32)
148149
logger_level = _get_logger_level(logger)
149150
connectivities = fuzzy_simplicial_set(
150151
X_conn,

tests/dask/conftest.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dask
34
import pytest
45
from dask.distributed import Client
56
from dask_cuda import LocalCUDACluster
@@ -29,8 +30,12 @@ def dist_client(cluster):
2930
gets an isolated client (connecting to the shared cluster is cheap).
3031
"""
3132
client = Client(cluster)
32-
yield client
33-
client.close()
33+
try:
34+
yield client
35+
finally:
36+
# Always deregister the global-default client, even if the test fails,
37+
# so it can't leak into later `client`-fixture (synchronous) tests.
38+
client.close()
3439

3540

3641
@pytest.fixture(scope="function")
@@ -42,5 +47,11 @@ def client():
4247
never touch the client object. Handing them ``None`` avoids spinning up a
4348
LocalCUDACluster and skips the distributed serialization round-trips of
4449
cupy chunks, which are pure overhead on the tiny test arrays.
50+
51+
Forces the synchronous scheduler so these tests can never be hijacked by a
52+
distributed client left as the global default by an earlier ``dist_client``
53+
test (which would route ``.compute()`` through the shared cluster and stall
54+
on a GIL-holding cupy op -> the random 60s pytest-timeout hangs).
4555
"""
46-
yield None
56+
with dask.config.set(scheduler="synchronous"):
57+
yield None

tests/dask/test_dask_regress_out.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_regress_out_categorical_dask(client, data_kind, dtype):
4949

5050
dask_X = dask_data.X.compute()
5151

52-
atol = 1e-5 if dtype == "float32" else 1e-7
52+
atol = 5e-5 if dtype == "float32" else 1e-7
5353
cp.testing.assert_allclose(dask_X, ref.X, atol=atol)
5454

5555

@@ -75,7 +75,7 @@ def test_regress_out_continuous_dask(client, data_kind, dtype):
7575

7676
dask_X = dask_data.X.compute()
7777

78-
atol = 1e-5 if dtype == "float32" else 1e-7
78+
atol = 5e-5 if dtype == "float32" else 1e-7
7979
cp.testing.assert_allclose(dask_X, ref.X, atol=atol)
8080

8181

tests/pertpy/test_guide_assignment.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,6 @@ def test_mixture_model_stores_params(guide_adata: AnnData) -> None:
190190
"mix_probs_0",
191191
"mix_probs_1",
192192
"threshold",
193-
"weight_Poisson",
194-
"weight_Normal",
195-
"lambda",
196-
"mu",
197-
"scale",
198193
]:
199194
assert col in guide_adata.var.columns, f"Missing column: {col}"
200195

@@ -209,41 +204,20 @@ def test_mixture_model_stores_params(guide_adata: AnnData) -> None:
209204
means = guide_adata.var["gaussian_mean"].dropna()
210205
assert (rates < means).all(), "Poisson rate should be < Gaussian mean"
211206

212-
# Crispat-compatible aliases should mirror the pertpy-style parameter names.
213-
np.testing.assert_allclose(
214-
guide_adata.var["weight_Poisson"].dropna(),
215-
guide_adata.var["mix_probs_0"].dropna(),
216-
)
217-
np.testing.assert_allclose(
218-
guide_adata.var["weight_Normal"].dropna(),
219-
guide_adata.var["mix_probs_1"].dropna(),
220-
)
221-
np.testing.assert_allclose(
222-
guide_adata.var["lambda"].dropna(),
223-
guide_adata.var["poisson_rate"].dropna(),
224-
)
225-
np.testing.assert_allclose(
226-
guide_adata.var["mu"].dropna(),
227-
guide_adata.var["gaussian_mean"].dropna(),
228-
)
229-
np.testing.assert_allclose(
230-
guide_adata.var["scale"].dropna(),
231-
guide_adata.var["gaussian_std"].dropna(),
232-
)
233207
assert guide_adata.var["threshold"].dropna().ge(1).all()
234208

235209

236210
def test_mixture_model_overwrites_existing_var_columns(guide_adata: AnnData) -> None:
237211
guide_adata.var["threshold"] = pd.Categorical(["old"] * guide_adata.n_vars)
238-
guide_adata.var["lambda"] = "old"
212+
guide_adata.var["poisson_rate"] = "old"
239213

240214
ga = rsc.ptg.GuideAssignment()
241215
ga.assign_mixture_model(guide_adata)
242216

243217
assert pd.api.types.is_float_dtype(guide_adata.var["threshold"])
244-
assert pd.api.types.is_float_dtype(guide_adata.var["lambda"])
218+
assert pd.api.types.is_float_dtype(guide_adata.var["poisson_rate"])
245219
assert guide_adata.var["threshold"].dropna().ge(1).all()
246-
assert np.isfinite(guide_adata.var["lambda"].dropna()).all()
220+
assert np.isfinite(guide_adata.var["poisson_rate"].dropna()).all()
247221

248222

249223
def test_mixture_model_sparse_input(guide_adata_sparse: AnnData) -> None:
@@ -375,7 +349,7 @@ def test_mixture_model_skip_low_count() -> None:
375349

376350

377351
def test_mixture_model_skip_max_count_below_two() -> None:
378-
"""Crispat skips guides whose non-zero counts never reach 2 UMIs."""
352+
"""Guides whose non-zero counts never reach 2 UMIs are skipped."""
379353
X = np.zeros((50, 2), dtype=np.float32)
380354
X[:25, :] = 1.0
381355

0 commit comments

Comments
 (0)