diff --git a/src/rapids_singlecell/decoupler_gpu/_helper/_Method.py b/src/rapids_singlecell/decoupler_gpu/_helper/_Method.py index 4e2e06a4..fe06ade9 100644 --- a/src/rapids_singlecell/decoupler_gpu/_helper/_Method.py +++ b/src/rapids_singlecell/decoupler_gpu/_helper/_Method.py @@ -83,6 +83,7 @@ def __call__( bsize: int | float = 5000, verbose: bool = False, pre_load: bool = False, + adj_pv_gpu: bool = False, **kwargs, ): return _run( @@ -98,5 +99,6 @@ def __call__( bsize=bsize, verbose=verbose, pre_load=pre_load, + adj_pv_gpu=adj_pv_gpu, **kwargs, ) diff --git a/src/rapids_singlecell/decoupler_gpu/_helper/_docs.py b/src/rapids_singlecell/decoupler_gpu/_helper/_docs.py index 10571750..4faf3e8d 100644 --- a/src/rapids_singlecell/decoupler_gpu/_helper/_docs.py +++ b/src/rapids_singlecell/decoupler_gpu/_helper/_docs.py @@ -88,6 +88,10 @@ tval Whether to return the t-value (`tval=True`) the coefficient of the fitted model (`tval=False`).""" +_adj_pv_gpu = """\ +adj_pv_gpu + Whether to use GPU for adjusting p-values.""" + _params = f"""\ Parameters ---------- @@ -99,7 +103,8 @@ {_empty} {_bsize} {_verbose} -{_pre_load}""" +{_pre_load} +{_adj_pv_gpu}""" docs = DocstringProcessor( net=_net, diff --git a/src/rapids_singlecell/decoupler_gpu/_helper/_pv.py b/src/rapids_singlecell/decoupler_gpu/_helper/_pv.py index 7e8c3fe7..08180a75 100644 --- a/src/rapids_singlecell/decoupler_gpu/_helper/_pv.py +++ b/src/rapids_singlecell/decoupler_gpu/_helper/_pv.py @@ -1,42 +1,209 @@ from __future__ import annotations import cupy as cp +import numba as nb +import numpy as np +# Reverse cumulative min along the last axis, per row (float64) +_rev_cummin64 = cp.RawKernel( + r""" +extern "C" __global__ +void rev_cummin64(const double* __restrict__ x, + double* __restrict__ y, + const int n_rows, + const int m) +{ + int r = blockDim.x * blockIdx.x + threadIdx.x; + if (r >= n_rows) return; -def fdr_bh_axis1(ps): + const double* xr = x + (size_t)r * m; + double* yr = y + (size_t)r * m; + + double cur = xr[m - 1]; + yr[m - 1] = cur; + + // right -> left + for (int j = m - 2; j >= 0; --j) { + double v = xr[j]; + cur = (v < cur) ? v : cur; + yr[j] = cur; + } +} +""", + "rev_cummin64", +) + + +def fdr_bh_axis1_cupy_optimized(ps, *, mem_gb: float = 4.0) -> cp.ndarray: """ - Benjamini–Hochberg adjusted p-values along axis=1 (rows). - ps: cupy.ndarray (n_rows, n_tests), values in [0, 1]. - Returns: cupy.ndarray of same shape. + Benjamini–Hochberg adjusted p-values along axis=1 (rows), GPU-optimized. + - Keeps values in float64 for CPU-equivalence. + - Uses int32 index buffers where safe. + - Processes rows in chunks to cap peak memory. + - Avoids a second argsort by building inverse permutation via scatter. + - Uses a custom kernel for reverse cumulative min. + + Parameters + ---------- + ps : cupy.ndarray, shape (n_rows, n_tests), dtype float64 (or castable) + P-values in [0, 1]. + mem_gb : float + Approx memory budget for temporaries per batch (default 4 GB). + + Returns + ------- + cupy.ndarray, shape (n_rows, n_tests), dtype float64 + BH-adjusted p-values on device. """ ps = cp.asarray(ps, dtype=cp.float64) if ps.ndim != 2: - raise ValueError("ps must be 2D (n_rows, n_tests) for axis=1.") - if not cp.issubdtype(ps.dtype, cp.number): - raise ValueError("`ps` must be numeric.") - if not bool((ps == cp.clip(ps, 0, 1)).all()): + raise ValueError("ps must be 2D (n_rows, n_tests).") + + # Bounds check without host transfer + if not bool((ps == cp.clip(ps, 0.0, 1.0)).all()): raise ValueError("`ps` must be within [0, 1].") - n_rows, m = ps.shape + n_rows, m = map(int, ps.shape) if m <= 1: return ps.copy() - # sort each row - order = cp.argsort(ps, axis=1) - ps_sorted = cp.take_along_axis(ps, order, axis=1) + # Precompute BH scale in float64: scale[j] = m / (j+1) + scale = m / cp.arange(1, m + 1, dtype=cp.float64) + + # Batch size: very conservative estimate of bytes/row for temps: + # ps_chunk(8) + order(4) + ps_sorted(8) + ps_bh(8) + ps_mon(8) + inv_order(4) + flat_idx(8) ~= 48 bytes * m + bytes_per_row = 48 * m + mem_bytes = int(mem_gb * (1024**3)) + B = max(1, mem_bytes // max(bytes_per_row, 1)) + + out = cp.empty_like(ps, dtype=cp.float64) + + threads = 256 # for the rev_cummin kernel + for s in range(0, n_rows, B): + e = min(n_rows, s + B) + R = e - s + + ps_chunk = ps[s:e] # (R, m) float64 + + # 1) per-row argsort (ascending); keep indices as int32 + order = cp.argsort(ps_chunk, axis=1).astype(cp.int32, copy=False) # (R, m) + + # 2) gather sorted values with flattened indexing (faster at large sizes) + row_base = (cp.arange(R, dtype=cp.int32) * m)[:, None] # (R,1) + flat_idx = (order.astype(cp.int32) + row_base).ravel() # (R*m,) + ps_sorted = ps_chunk.ravel()[flat_idx].reshape(R, m) # (R, m) float64 + + # 3) BH scaling + ps_bh = ps_sorted * scale # (R, m) float64 + + # 4) reverse cumulative min via custom kernel + ps_mon = cp.empty_like(ps_bh) + blocks = (R + threads - 1) // threads + _rev_cummin64((blocks,), (threads,), (ps_bh, ps_mon, R, m)) + + # 5) build inverse permutation without argsort (scatter) + inv_order = cp.empty_like(order, dtype=cp.int32) # (R, m) int32 + # Broadcast-safe RHS; no R*m materialization for RHS + inv_order[cp.arange(R)[:, None], order] = cp.arange(m, dtype=cp.int32) + + # 6) unsort back via flattened gather + flat_idx2 = (inv_order.astype(cp.int32) + row_base).ravel() + out[s:e] = ps_mon.ravel()[flat_idx2].reshape(R, m) + + # 7) clamp to [0,1] + return cp.clip(out, 0.0, 1.0).get().astype(np.float32) + + +@nb.jit(nopython=True, cache=True) +def _fdr_bh_single_row(ps_row, m): + """ + Apply Benjamini-Hochberg correction to a single row. + """ + # Sort the row and get indices + order = np.argsort(ps_row) + ps_sorted = ps_row[order] # BH scale: p_(i) * m / i - i = cp.arange(1, m + 1, dtype=ps.dtype) - ps_bh = ps_sorted * (m / i) + ps_bh = np.empty_like(ps_sorted, dtype=np.float64) + for i in range(m): + ps_bh[i] = ps_sorted[i] * (m / (i + 1)) + + # Reverse cumulative min + ps_rev = np.empty_like(ps_bh, dtype=np.float64) + for i in range(m): + ps_rev[i] = ps_bh[m - 1 - i] - # reverse cumulative min across columns (no ufunc.accumulate in CuPy) - ps_rev = ps_bh[:, ::-1].copy() for j in range(1, m): - ps_rev[:, j] = cp.minimum(ps_rev[:, j], ps_rev[:, j - 1]) - ps_monotone = ps_rev[:, ::-1] + ps_rev[j] = min(ps_rev[j], ps_rev[j - 1]) - # unsort back to original column order - inv_order = cp.argsort(order, axis=1) - ps_adj = cp.take_along_axis(ps_monotone, inv_order, axis=1) + # Reverse back + ps_monotone = np.empty_like(ps_rev, dtype=np.float64) + for i in range(m): + ps_monotone[i] = ps_rev[m - 1 - i] + + # Unsort back to original order + ps_adj = np.empty_like(ps_monotone, dtype=np.float64) + for i in range(m): + ps_adj[order[i]] = ps_monotone[i] + + # Clip to [0, 1] + for i in range(m): + ps_adj[i] = max(0.0, min(1.0, ps_adj[i])) + + return ps_adj + + +@nb.jit(nopython=True, parallel=True, cache=True) +def _fdr_bh_parallel(ps, m): + """ + Apply Benjamini-Hochberg correction to all rows in parallel. + """ + n_rows = ps.shape[0] + result = np.empty_like(ps, dtype=np.float64) + + for i in nb.prange(n_rows): + result[i] = _fdr_bh_single_row(ps[i], m) + + return result + + +def fdr_bh_axis1_numba(ps): + """ + Benjamini–Hochberg adjusted p-values along axis=1 (rows). + ps: numpy.ndarray (n_rows, n_tests), values in [0, 1]. + Returns: numpy.ndarray of same shape. + """ + ps = np.asarray(ps, dtype=np.float64) + if ps.ndim != 2: + raise ValueError("ps must be 2D (n_rows, n_tests) for axis=1.") + if not np.issubdtype(ps.dtype, np.number): + raise ValueError("`ps` must be numeric.") + if not np.all((ps >= 0) & (ps <= 1)): + raise ValueError("`ps` must be within [0, 1].") + + n_rows, m = ps.shape + if m <= 1: + return ps.copy().astype(np.float32) + + # Process each row in parallel + result = _fdr_bh_parallel(ps, m) + return result.astype(np.float32) + + +def fdr_bh_axis1(ps: np.ndarray, *, if_gpu: bool = False) -> np.ndarray: + """ + Main function with CPU/GPU selection and optimization options. + + Parameters: + ----------- + ps : array-like + Input p-values array + if_gpu : bool, default=False + Whether to use GPU (CuPy) or CPU (Numba) + """ - return cp.clip(ps_adj, 0, 1).get().astype(cp.float32) + if if_gpu: + out = fdr_bh_axis1_cupy_optimized(ps) + else: + out = fdr_bh_axis1_numba(ps) + return out diff --git a/src/rapids_singlecell/decoupler_gpu/_helper/_run.py b/src/rapids_singlecell/decoupler_gpu/_helper/_run.py index 92b58cb0..5d2c02d4 100644 --- a/src/rapids_singlecell/decoupler_gpu/_helper/_run.py +++ b/src/rapids_singlecell/decoupler_gpu/_helper/_run.py @@ -100,6 +100,7 @@ def _run( bsize: int | float = 5000, verbose: bool = False, pre_load: bool = False, + adj_pv_gpu: bool = False, **kwargs, ) -> tuple[pd.DataFrame, pd.DataFrame] | AnnData | None: _log(f"{name} - Running {name}", level="info", verbose=verbose) @@ -173,7 +174,7 @@ def _run( pv = pd.DataFrame(pv, index=obs, columns=sources) if name != "mlm": _log(f"{name} - adjusting p-values by FDR", level="info", verbose=verbose) - pv.loc[:, :] = fdr_bh_axis1(pv.values) + pv.loc[:, :] = fdr_bh_axis1(pv.values, if_gpu=adj_pv_gpu) else: pv = None _log(f"{name} - done", level="info", verbose=verbose) diff --git a/tests/decoupler/test_pv.py b/tests/decoupler/test_pv.py index 1a1a4804..26a4c6da 100644 --- a/tests/decoupler/test_pv.py +++ b/tests/decoupler/test_pv.py @@ -1,19 +1,22 @@ from __future__ import annotations import numpy as np +import pytest import scipy.stats as sts import rapids_singlecell.decoupler_gpu as dc from rapids_singlecell.decoupler_gpu._helper._pv import fdr_bh_axis1 +@pytest.mark.parametrize("if_gpu", [True, False]) def test_func_mlm( adata, net, + if_gpu, ): dc.mlm(data=adata, net=net, tmin=3) dc_pv = adata.obsm["padj_mlm"] - adj = fdr_bh_axis1(dc_pv.values) + adj = fdr_bh_axis1(dc_pv.values, if_gpu=if_gpu) np.testing.assert_allclose( adj, sts.false_discovery_control(dc_pv.values, axis=1, method="bh") )