From bc5127333d7881c47dde6c36872b554d1dd8fee0 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 15:38:28 +0200 Subject: [PATCH 01/36] first iteration of refactor --- .gitignore | 2 + CMakeLists.txt | 1 + src/rapids_singlecell/_cuda/nb_types.h | 7 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 400 +++++++ .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 976 ++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon.cu | 91 ++ .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 345 +++++++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 518 ++++++++++ .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 853 +++++++++++++++ .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 182 ++++ .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 104 ++ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 861 +++++++++++++++ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 292 ++++++ .../wilcoxon/wilcoxon_sparse_kernels.cuh | 651 ++++++++++++ .../tools/_rank_genes_groups/__init__.py | 182 +++- .../tools/_rank_genes_groups/_core.py | 340 ++++-- .../tools/_rank_genes_groups/_utils.py | 54 +- .../tools/_rank_genes_groups/_wilcoxon.py | 966 +++++++++++++++-- .../_rank_genes_groups/_wilcoxon_binned.py | 16 +- tests/test_rank_genes_groups_ttest.py | 10 +- tests/test_rank_genes_groups_wilcoxon.py | 401 ++++++- 21 files changed, 7054 insertions(+), 198 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh diff --git a/.gitignore b/.gitignore index c0e83438..6994e147 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,8 @@ coverage.xml .cursor/ .claude/ CLAUDE.md +.codex # tmp_scripts tmp_scripts/ +benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..85d33e91 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 905e1e07..eb343815 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -42,6 +42,13 @@ using gpu_array = nb::ndarray; template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; + +template +using host_array_2d = nb::ndarray>; + // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..8b6af5f6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,6 +2,27 @@ #include +__device__ __forceinline__ double wilcoxon_block_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; + } + return 0.0; +} + /** * Kernel to compute tie correction factor for Wilcoxon test. * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied @@ -142,3 +163,382 @@ __global__ void average_rank_kernel(const double* __restrict__ sorted_vals, rk[si[i]] = avg_rank; } } + +/** + * OVO dense rank core. + * + * ref_sorted is F-order and sorted independently for every column. + * grp_data is F-order and contains test-group rows concatenated by + * grp_offsets. One block computes one (column, test-group) result. + * + * This intentionally centralizes the OVO math; host/device and CSR/CSC/dense + * paths only need to materialize bounded dense column batches that feed this + * kernel. + */ +__global__ void ovo_rank_dense_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_data, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, + bool compute_tie_corr) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_data + (long long)col * n_all_grp + g_start; + + __shared__ double warp_buf[32]; + double local_rank = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + for (int j = 0; j < n_grp; ++j) { + float u = grp_col[j]; + n_lt_grp += (u < v); + n_eq_grp += (u == v); + } + + local_rank += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + double total_rank = wilcoxon_block_sum(local_rank, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = total_rank; + } + + if (!compute_tie_corr) return; + __syncthreads(); + + double local_tie = 0.0; + + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int count = lo - i; + for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); + if (count > 1) { + double t = (double)count; + local_tie += t * t * t - t; + } + } + } + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + bool seen_in_group = false; + for (int j = 0; j < i; ++j) { + if (grp_col[j] == v) { + seen_in_group = true; + break; + } + } + if (seen_in_group) continue; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + if (lo < n_ref && ref_col[lo] == v) continue; + + int count = 0; + for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); + if (count > 1) { + double t = (double)count; + local_tie += t * t * t - t; + } + } + + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[(size_t)grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +__global__ void ovo_rank_presorted_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, + int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + __shared__ double warp_buf[32]; + double local_rank = 0.0; + + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_rank += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + double total_rank = wilcoxon_block_sum(local_rank, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = total_rank; + } + + if (!compute_tie_corr) return; + __syncthreads(); + + double local_tie = 0.0; + int grp_lb_tie = 0, grp_ub_tie = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + lo = grp_lb_tie; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb_tie = lb; + + lo = (grp_ub_tie > lb) ? grp_ub_tie : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub_tie = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + int ref_lb_tie = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + int lo = ref_lb_tie, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb_tie = lo; + if (lo < n_ref && ref_col[lo] == v) continue; + + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[(size_t)grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +/** + * OVR dense rank core. + * + * sorted_vals and sorter are F-order outputs of sorting each column of the + * current dense block. The kernel directly accumulates rank sums per group, + * avoiding a full ranks matrix and a group one-hot matrix multiply. + */ +__global__ void ovr_rank_dense_kernel(const float* __restrict__ sorted_vals, + const int* __restrict__ sorter, + const int* __restrict__ group_codes, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, + bool compute_tie_corr) { + int col = blockIdx.x; + if (col >= n_cols) return; + + const float* sv = sorted_vals + (long long)col * n_rows; + const int* si = sorter + (long long)col * n_rows; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + float val = sv[i]; + + int lo = 0, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + int tie_start = lo; + + lo = i; + hi = n_rows - 1; + while (lo < hi) { + int mid = lo + ((hi - lo + 1) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + int tie_end = lo; + double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; + + int row = si[i]; + int group = group_codes[row]; + if (group >= 0 && group < n_groups) { + atomicAdd(&rank_sums[(size_t)group * n_cols + col], avg_rank); + } + + if (compute_tie_corr && i == tie_end) { + double t = (double)(tie_end - tie_start + 1); + if (t > 1.0) local_tie += t * t * t - t; + } + } + + if (!compute_tie_corr) return; + + __shared__ double warp_buf[32]; + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh new file mode 100644 index 00000000..5b4c0b8c --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -0,0 +1,976 @@ +#pragma once + +#include + +// ============================================================================ +// Warp reduction helper (sum doubles across block via warp_buf) +// ============================================================================ + +__device__ __forceinline__ double block_reduce_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v2 += __shfl_down_sync(0xffffffff, v2, off); + return v2; // only lane 0 of warp 0 has the final result + } + return 0.0; +} + +// ============================================================================ +// Parallel tie correction — all threads collaborate. +// +// For each unique value in the combined sorted (ref, grp) arrays, accumulate +// t^3 - t where t = count of that value. Uses two passes: +// 1. Iterate unique values in ref_col, count in both arrays. +// 2. Iterate unique values in grp_col that do NOT appear in ref_col. +// +// Incremental binary search bounds exploit monotonicity within each thread's +// stride to reduce total search work. +// +// Caller must __syncthreads() before calling. warp_buf is reused for +// reduction (32 doubles, shared memory). +// ============================================================================ + +__device__ __forceinline__ void compute_tie_correction_parallel( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double* warp_buf, double* out) { + double local_tie = 0.0; + + // Pass 1: unique values in ref_col + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + // Count in grp: incremental lower/upper bound + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb = lb; + + lo = (grp_ub > lb) ? grp_ub : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp_col that are absent from ref_col + int ref_lb = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + + // Incremental lower_bound in ref + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb = lo; + + if (lo >= n_ref || ref_col[lo] != v) { + // Value not in ref — count in grp only (upper_bound from i+1) + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + } + + // Block-wide reduction + double tie_sum = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Batched rank sums — pre-sorted (binary search, no shared memory sort) +// Used by the OVO streaming pipeline in wilcoxon_streaming.cu. +// +// Incremental binary search: each thread carries forward lower/upper bound +// positions across loop iterations, exploiting the monotonicity of the +// sorted grp_col values within each thread's stride. +// ============================================================================ + +__global__ void batched_rank_sums_presorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + // Incremental binary search bounds (advance monotonically per thread) + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int lo, hi; + + // Lower bound in ref (from ref_lb) + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + // Upper bound in ref (from max(ref_ub, n_lt_ref)) + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + // Lower bound in grp (from grp_lb) + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + // Upper bound in grp (from max(grp_ub, n_lt_grp)) + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 1 fused kernel: smem bitonic sort + binary search rank sums +// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) +// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) +// ============================================================================ + +__global__ void ovo_fused_sort_rank_kernel( + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int padded_grp_size, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // skip groups it's already handling. Each group owns its own + // rank_sums row, so the two kernels' writes never alias. + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + + // Load group data into shared memory, pad with +INF + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF + __syncthreads(); + + // Bitonic sort in shared memory + for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + } + __syncthreads(); + } + } + + // Binary search each sorted grp element against sorted ref + // Incremental bounds: values are monotonic within each thread's stride + const float* ref_col = ref_sorted + (long long)col * n_ref; + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + int lo, hi; + + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + // Block reduction → write rank_sums + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + // Parallel tie correction (grp_smem is sorted shared memory) + compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 2 helper: tie contribution of the sorted reference alone. +// One block per column. The medium unsorted-rank kernel uses this as a base +// and only adds group-only/overlap deltas from the unsorted group values. +// ============================================================================ + +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +__global__ void ovo_small64_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > TIER0_64_GROUP_THRESHOLD) return; + + __shared__ float grp_smem[TIER0_64_GROUP_THRESHOLD]; + __shared__ double warp_buf[WARP_REDUCE_BUF]; + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + const float POS_INF = __int_as_float(0x7f800000); + if (threadIdx.x < TIER0_64_GROUP_THRESHOLD) { + grp_smem[threadIdx.x] = + (threadIdx.x < n_grp) ? grp_col[threadIdx.x] : POS_INF; + } + __syncthreads(); + + for (int k = 2; k <= TIER0_64_GROUP_THRESHOLD; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + int i = threadIdx.x; + int ixj = i ^ j; + if (i < TIER0_64_GROUP_THRESHOLD && ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + __syncthreads(); + } + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + if (threadIdx.x < n_grp) { + float v = grp_smem[threadIdx.x]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + lo = 0; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && + (threadIdx.x == 0 || v != grp_smem[threadIdx.x - 1])) { + double combined = (double)(n_eq_ref + n_eq_grp); + if (combined > 1.0) { + local_tie_delta += combined * combined * combined - combined; + } + if (n_eq_ref > 1) { + double cr = (double)n_eq_ref; + local_tie_delta -= cr * cr * cr - cr; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Tier 2 fused kernel: no-sort direct rank for medium groups. +// +// Avoids the smem bitonic sort for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks are computed from ref binary searches plus an +// in-group scan over unsorted shared values. Tie correction starts from +// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. +// ============================================================================ + +__global__ void ovo_medium_unsorted_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Warp-scoped tie correction for Tier 0. +// +// Sorted values live in a 32-lane register (one per lane, with unused lanes +// carrying +INF). Walks unique values via lane-step differentials and +// counts ties across the sorted ref column via binary search. All the +// sync is __syncwarp — no smem, no __syncthreads. +// ============================================================================ + +__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_tie = 0.0; + + // Pass 1: for each unique value in ref_col, count occurrences in ref and + // in the sorted group (held in register v_lane across 32 lanes). + for (int base = 0; base < n_ref; base += 32) { + int i = base + lane; + bool in_ref_lane = (i < n_ref); + float v = in_ref_lane ? ref_col[i] : 0.0f; + bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); + int cnt_ref = 0; + if (is_first) { + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + cnt_ref = lo - i; + } + + // Count in grp: look up how many lanes hold v_lane == v. All lanes + // execute the shuffle loop; only lanes owning a unique ref value use + // the result. + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + float vi = __shfl_sync(0xffffffff, v_lane, lane_i); + if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (is_first) { + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp that are absent from ref. + // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, + // binary-search ref for v. If not present, count consecutive matching + // lanes (tie block). + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + bool in_ref = false; + if (first_in_grp) { + // Binary search in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + in_ref = (lo < n_ref) && (ref_col[lo] == v); + } + + // Count how many lanes ≥ this lane hold the same v. Keep the shuffle + // uniform across active lanes even though only unique, ref-absent + // group values consume the count. + int cnt = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && + vi == v) { + ++cnt; + } + } + if (first_in_grp && !in_ref && cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + return local_tie; // meaningful on lane 0. +} + +__device__ __forceinline__ double tier0_tie_delta_warp( + const float* ref_col, int n_ref, float v_lane, int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_delta = 0.0; + + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (first_in_grp) { + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int ref_lb = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - ref_lb; + + double combined = (double)(cnt_ref + cnt_grp); + if (combined > 1.0) { + local_delta += combined * combined * combined - combined; + } + if (cnt_ref > 1) { + double cr = (double)cnt_ref; + local_delta -= cr * cr * cr - cr; + } + } + } + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_delta += __shfl_down_sync(0xffffffff, local_delta, off); + return local_delta; // meaningful on lane 0. +} + +// ============================================================================ +// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// +// Each warp independently: +// 1. Loads ≤ 32 group values into a single register (one per lane, +// padded with +INF). +// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. +// 3. Binary-searches into sorted ref for each lane's value and +// accumulates the rank-sum term. +// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. +// +// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair +// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// independently at full throughput. +// +// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. +// ============================================================================ + +__global__ void ovo_warp_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + constexpr int WARPS_PER_BLOCK = 8; + int warp_id = threadIdx.x >> 5; + int lane = threadIdx.x & 31; + + int col = blockIdx.x; + int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // This kernel only handles groups that fit in a single warp (one value + // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // kernel; since each group owns its own row in rank_sums/tie_corr, the + // two kernels interlace into the output without conflict. + if (n_grp > TIER0_GROUP_THRESHOLD) return; + + if (n_grp == 0) { + if (lane == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // One value per lane, pad with +INF so sort pushes them to the end. + const float POS_INF = __int_as_float(0x7f800000); + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + float x = (lane < n_grp) ? grp_col[lane] : POS_INF; + unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); + + // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. + for (int k = 1; k <= 16; k <<= 1) { + for (int j = k; j > 0; j >>= 1) { + float y = __shfl_xor_sync(0xffffffff, x, j); + bool asc = (((lane & (k << 1)) == 0)); + bool take_min = (((lane & j) == 0) == asc); + x = take_min ? fminf(x, y) : fmaxf(x, y); + } + } + + // After sort, x[lane] holds the lane-th smallest group value (lanes + // ≥ n_grp hold +INF). Binary-search each value into the sorted ref. + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + + if (lane < n_grp) { + float v = x; + // Lower bound in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + // Upper bound in ref. + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + // In-group counts: in the sorted warp-register x, count lanes < this + // one that hold strictly less, and lanes with equal value. + int n_lt_grp = 0; + int n_eq_grp_offset = 0; // tied lanes strictly before this one + int n_eq_grp_after = 1; // count self +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + if (lane_i >= n_grp) continue; + float vi = __shfl_sync(active_mask, v, lane_i); + if (lane_i < lane) { + if (vi < v) + ++n_lt_grp; + else if (vi == v) + ++n_eq_grp_offset; + } else if (lane_i > lane) { + if (vi == v) ++n_eq_grp_after; + } + } + int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; + // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane + // gets the same mid-rank. This matches the Tier 1 accumulation. + local_sum = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; + + if (!compute_tie_corr) return; + + // Warp-scoped tie correction. + double tie_sum; + if (ref_tie_sums != nullptr) { + tie_sum = ref_tie_sums[col] + + tier0_tie_delta_warp(ref_col, n_ref, x, n_grp, active_mask); + } else { + tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + } + if (lane == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d25f7d0f..0ab5b26c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -8,6 +8,7 @@ using namespace nb::literals; // Constants for kernel launch configuration constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int OVO_THREADS_PER_BLOCK = 256; static inline int round_up_to_warp(int n) { int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; @@ -37,6 +38,43 @@ static inline void launch_average_rank(const double* sorted_vals, CUDA_CHECK_LAST_ERROR(average_rank_kernel); } +static inline void launch_ovo_rank_dense( + const float* ref_sorted, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, cudaStream_t stream) { + dim3 block(OVO_THREADS_PER_BLOCK); + dim3 grid(n_cols, n_groups); + ovo_rank_dense_kernel<<>>( + ref_sorted, grp_data, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, n_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_rank_dense_kernel); +} + +static inline void launch_ovo_rank_presorted( + const float* ref_sorted, const float* grp_sorted, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, cudaStream_t stream) { + dim3 block(OVO_THREADS_PER_BLOCK); + dim3 grid(n_cols, n_groups); + ovo_rank_presorted_kernel<<>>( + ref_sorted, grp_sorted, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, n_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_rank_presorted_kernel); +} + +static inline void launch_ovr_rank_dense( + const float* sorted_vals, const int* sorter, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, cudaStream_t stream) { + int threads_per_block = round_up_to_warp(n_rows); + dim3 block(threads_per_block); + dim3 grid(n_cols); + ovr_rank_dense_kernel<<>>( + sorted_vals, sorter, group_codes, rank_sums, tie_corr, n_rows, n_cols, + n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; @@ -65,6 +103,59 @@ void register_bindings(nb::module_& m) { }, "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "stream"_a = 0); + + m.def( + "ovo_rank_dense", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + std::uintptr_t stream) { + launch_ovo_rank_dense( + ref_sorted.data(), grp_data.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, (cudaStream_t)stream); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovo_rank_presorted", + [](gpu_array_f ref_sorted, + gpu_array_f grp_sorted, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + std::uintptr_t stream) { + launch_ovo_rank_presorted( + ref_sorted.data(), grp_sorted.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, (cudaStream_t)stream); + }, + "ref_sorted"_a, "grp_sorted"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovr_rank_dense", + [](gpu_array_f sorted_vals, + gpu_array_f sorter, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, std::uintptr_t stream) { + launch_ovr_rank_dense(sorted_vals.data(), sorter.data(), + group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, + compute_tie_corr, (cudaStream_t)stream); + }, + "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh new file mode 100644 index 00000000..dd50d2cb --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -0,0 +1,345 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 64; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; +// Default thread-per-block for utility kernels (extract, gather, offsets, +// etc.). +constexpr int UTIL_BLOCK_SIZE = 256; +// Scratch slots for warp-level reduction (one slot per warp, 32 warps max). +constexpr int WARP_REDUCE_BUF = 32; +// Max group size for the super-fast "warp-per-(col,group)" fused kernel +// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in +// registers via warp-shuffle bitonic sort — no smem sort buffer, no +// __syncthreads(). Blocks pack 8 warps so block launch overhead is +// amortised 8× across (col, group) work items. This path is the fast +// route for per-celltype perturbation-style workloads where most test +// groups have only a few dozen cells. +constexpr int TIER0_GROUP_THRESHOLD = 32; +// Second small-group tier for perturbation workloads where most groups are +// slightly larger than one warp. Uses one compact shared-memory sort block per +// (column, group), avoiding the heavier Tier 2 in-group scan. +constexpr int TIER0_64_GROUP_THRESHOLD = 64; +// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation +// workloads most groups sit below this range, where avoiding a full smem +// bitonic sort wins despite the O(n^2) in-group count. +constexpr int TIER2_GROUP_THRESHOLD = 512; +// Max group size for the fused smem-sort rank kernel (Tier 1 fast path). +// Beyond this, fall back to CUB segmented sort + binary-search rank kernel. +constexpr int TIER1_GROUP_THRESHOLD = 2500; +// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes +// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = +// fewer kernel launches; smaller = less per-stream memory. 128M items × 4B = +// 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. +constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; + +// --------------------------------------------------------------------------- +// RAII guard for cudaHostRegister. Unregisters on scope exit even when an +// exception unwinds — prevents leaked host pinning on stream-sync failures. +// --------------------------------------------------------------------------- +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory is fine; anything else means the + // subsequent kernels would read garbage from an unmapped + // pointer, so surface the error immediately. + if (err == cudaErrorHostMemoryAlreadyRegistered) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// --------------------------------------------------------------------------- +// Small allocation pool for temporary CUDA buffers. The previous PR used RMM +// here, but these sparse Wilcoxon kernels only need scoped scratch memory; +// using cudaMalloc keeps this module independent of an extra build-time +// dependency. +// --------------------------------------------------------------------------- +struct RmmPool { + std::vector bufs; + + ~RmmPool() { + for (void* ptr : bufs) { + if (ptr) cudaFree(ptr); + } + } + + template + T* alloc(size_t count) { + if (count == 0) count = 1; + void* ptr = nullptr; + cudaError_t err = cudaMalloc(&ptr, count * sizeof(T)); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("cudaMalloc failed in Wilcoxon scratch pool: ") + + cudaGetErrorString(err)); + } + bufs.push_back(ptr); + return static_cast(ptr); + } +}; + +struct ScopedCudaBuffer { + void* ptr = nullptr; + + explicit ScopedCudaBuffer(size_t bytes) { + if (bytes == 0) bytes = 1; + cudaError_t err = cudaMalloc(&ptr, bytes); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("cudaMalloc failed in Wilcoxon scoped buffer: ") + + cudaGetErrorString(err)); + } + } + + ~ScopedCudaBuffer() { + if (ptr) cudaFree(ptr); + } + + void* data() { + return ptr; + } + + ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; + ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; +}; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] + * on-device. One thread per output slot. */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Fill per-row stats codes for a pack of K groups. + * Given pack_grp_offsets (size K+1, relative to pack start), write + * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). + * Binary search within the K+1 offsets. */ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). + * Templated so that 64-bit global indptrs can produce 32-bit pack-local + * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). + */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped + * pinned host memory. Block-per-row; threads in the block cooperate on the + * row's nnz. Each nnz is read from host over PCIe exactly once — no + * intermediate native-dtype GPU buffer, no second GPU pass. + * + * h_data / h_indices: device-accessible pointers into mapped pinned host + * memory (cudaHostRegisterMapped). + * d_indptr_full: full-matrix indptr on device. + * d_row_ids: rows to gather (size n_target_rows). + * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with + * out_indptr[i+1] - out_indptr[i] equal to the source row's + * nnz. + * + * Slot dispatch: + * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = + * fixed_slot (used for the Ref phase where every row maps to the same + * slot). slot ∉ [0, n_groups_stats) skips accumulation. + */ +template +__global__ void csr_gather_cast_accumulate_mapped_kernel( + const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, + const IndptrT* __restrict__ d_indptr_full, + const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, + const int* __restrict__ d_stats_codes, int fixed_slot, + float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sums, bool compute_sq_sums, + bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int src_row = d_row_ids[r]; + IndptrT rs = d_indptr_full[src_row]; + IndptrT re = d_indptr_full[src_row + 1]; + int row_nnz = (int)(re - rs); + int ds = d_out_indptr[r]; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { + InT v_in = h_data[rs + i]; + int c = (int)h_indices[rs + i]; + double v = (double)v_in; + d_out_data_f32[ds + i] = (float)v_in; + d_out_indices[ds + i] = c; + if (accumulate) { + if (compute_sums) { + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + } + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + +/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. + * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. + */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} + +// ============================================================================ +// CSR → dense F-order extraction (templated on data type) +// ============================================================================ + +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + row] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_unsorted_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_start && c < col_stop) { + out[(long long)(c - col_start) * n_target + row] = data[p]; + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..7ad20b01 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,518 @@ +#pragma once + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything. Reference rows are extracted and sorted once + * across all columns, then each group sub-batch ranks against that cached + * reference slice. This mirrors the fast host-CSR path and avoids redoing the + * reference dense extraction + segmented sort for every column sub-batch. + */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t max_ref_cols = 2147483647LL / (size_t)n_ref; + if (max_ref_cols == 0) { + throw std::runtime_error( + "OVO device CSR reference group exceeds CUB int item limit"); + } + int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) { + size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; + size_t target_bytes = free_bytes / 3; + if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { + size_t mem_cols = target_bytes / bytes_per_col; + if (mem_cols > 0 && mem_cols < (size_t)ref_cache_cols) { + ref_cache_cols = (int)mem_cols; + } + } + } + if (ref_cache_cols < 1) ref_cache_cols = 1; + + RmmPool pool; + + size_t cub_temp_bytes = 0; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_grp_bytes; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* grp_dense; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].cub_temp = + needs_tier3 ? pool.alloc(cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { + int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); + size_t cache_ref_items = (size_t)n_ref * cache_cols; + + ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_seg_offsets_buf((size_t)(cache_cols + 1) * + sizeof(int)); + float* d_ref_dense = (float*)ref_dense_buf.data(); + float* d_ref_sorted = (float*)ref_sorted_buf.data(); + int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); + + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float)); + int tpb_ref_extract = round_up_to_warp(n_ref); + int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, + cache_col, cache_col + cache_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, 0); + + size_t ref_cub_bytes = 0; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)cache_ref_items, cache_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); + size_t ref_temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, + (int)cache_ref_items, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + + int col = cache_col; + int cache_stop = cache_col + cache_cols; + int batch_idx = 0; + while (col < cache_stop) { + int sb_cols = std::min(sub_batch_cols, cache_stop - col); + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = + d_ref_sorted + (size_t)(col - cache_col) * n_ref; + + cudaMemsetAsync(buf.grp_dense, 0, + sb_grp_items_actual * sizeof(float), stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, + buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = + (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSR streaming: ") + + cudaGetErrorString(err)); + } + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant, but extracts rows via lookup maps so it can operate on + * native CSC input without converting the whole matrix. + */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, buf.sub_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSC streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..feb86e57 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,853 @@ +#pragma once + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are written back to host per sub-batch. + */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Max nnz across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // GPU copies of row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + + // Pin only the sparse input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range (native dtype) ---- + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, + buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, + compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. + * + * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full + * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack + * gathers only its rows over PCIe via a UVA kernel — the full matrix is never + * transferred to GPU. + * + * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort + * to d_ref_sorted (cached for the whole run). + * Phase 2 (per pack, round-robin across N_STREAMS): + * 1. rebase per-pack output indptr from the pre-uploaded global compacted + * indptr. + * 2. rebase per-pack group offsets + build per-row stats codes. + * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes + * compacted f32 data + indices and accumulates per-group stats. + * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * + * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized + * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). + * Full CSR stays on host (pinned-mapped). + */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int n_cols, + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + bool compute_nnz, bool compute_sums, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Use IndptrT for the global compacted indptr because the grp side can + // exceed 2^31 nnz on very large / dense matrices. Ref always fits in + // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the + // downstream CUB segmented-sort temp sizing. + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + int nnz_i = (int)(h_indptr[r + 1] - h_indptr[r]); + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows (IndptrT). + std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + struct Pack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; + }; + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = sub_batch_cols; + { + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = + GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + if (!can_add) { + size_t sb_size = + std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + } + for (const Pack& pk : packs) { + int K = pk.end - pk.first; + if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; + if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; + if (K > max_pack_K) max_pack_K = K; + int pack_items = pk.n_rows * pk.sb_cols; + if (pack_items > max_pack_items) max_pack_items = pack_items; + if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; + } + int max_group_rows = max_pack_rows; + size_t max_sub_items = (size_t)max_pack_items; + if (max_pack_rows == 0) return; + + RmmPool pool; + + // Zero stats outputs. + if (compute_sums) { + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + + // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + size_t full_nnz = (size_t)h_indptr[n_full_rows]; + HostRegisterGuard _pin_data(const_cast(h_data), + full_nnz * sizeof(InT), cudaHostRegisterMapped); + HostRegisterGuard _pin_indices(const_cast(h_indices), + full_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + + // Get device-accessible pointers (UVA makes these equal to host ptrs on + // Linux x86-64, but the API is the safe/portable way). + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (full_nnz > 0) { + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // ---- Upload row_ids + compacted indptrs + group boundaries ---- + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + { + ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); + ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); + ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); + ScopedCudaBuffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float)); + ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + // Upload ref compacted indptr + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Fused gather + cast + stats for ref (fixed slot = n_test). One + // pass over PCIe, no intermediate native-dtype GPU buffer. + if (n_ref > 0 && ref_nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, + d_ref_indptr, /*d_stats_codes=*/nullptr, + /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Extract ref dense (F-order) from compacted CSR. + cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + { + csr_extract_dense_identity_rows_unsorted_kernel + <<>>(d_ref_data_f32, d_ref_indices, + d_ref_indptr, d_ref_dense, n_ref, + 0, n_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + } + + // Segmented sort ref_dense by column → ref_sorted + size_t ref_cub_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), + n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + size_t temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, + (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + } // ref scratch drops here + + // ---- Phase 2: Per-pack streaming ---- + auto t1 = make_tier1_config(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int max_segments = max_pack_K * max_pack_sb_cols; + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams + + for (int p = 0; p < (int)packs.size(); p++) { + const Pack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; + int pack_tier3_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_tier3_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr from pre-uploaded global compacted indptr + // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Build per-pack group offsets on GPU (on this stream) — needed to + // compute stats codes before the fused gather kernel can run. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Fill per-row stats codes for this pack + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Fused gather + cast + stats for the pack. One pass over PCIe + // (reads mapped host via UVA), no intermediate native-dtype GPU + // buffer, writes f32 + indices + atomics. + if (pack.nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, + d_grp_row_ids + row_start, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, + buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + d_group_sq_sums, d_group_nnz, pack_rows, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Per col sub-batch + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = pack_rows * sb_cols; + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_grp_dense, pack_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + int skip_le = 0; + bool run_tier0 = pack_t1.use_tier0; + bool run_tier0_64 = pack_t1.any_tier0_64; + bool run_tier2 = pack_t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, pack_rows, sb_cols, K, + compute_tie_corr, stream); + if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + if (pack_t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (pack_has_above_t2 && pack_t1.use_tier1) { + dim3 grid(sb_cols, K); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, pack_t1.padded_grp_size, + upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (pack_has_above_t2) { + int n_seg = pack_n_sort_groups * sb_cols; + { + int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + buf.d_pack_grp_offsets, buf.d_sort_group_ids, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, + pack_n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_grp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, + sb_items, n_seg, buf.d_grp_seg_offsets, + buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + dim3 grid(sb_cols, K); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_rank_sums, + sb_cols * sizeof(double), + sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync( + d_tie_corr + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_tie_corr, + sb_cols * sizeof(double), sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in ovo csr host streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..afac20f2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,182 @@ +#pragma once + +/** + * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Group ids are relative to grp_offsets, and ranges still point into the + * original dense group layout so the presorted rank kernel can read from the + * normal per-group positions. + */ +__global__ void build_tier3_seg_begin_end_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. + * Output must be pre-zeroed. + */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; +} + +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. + */ +struct Tier1Config { + int max_grp_size = 0; + int min_grp_size = 0; + bool use_tier0 = + false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) + bool use_tier1 = + false; // any group needs > tier0 but fits in tier1 smem sort + bool any_above_t0 = + false; // at least one group exceeds TIER0_GROUP_THRESHOLD + bool any_tier0_64 = false; // any group needs Tier 0.5: (T0, T0_64] + bool any_tier2 = false; // any group needs Tier 2: (T0_64, T2] + bool any_above_t2 = + false; // at least one group exceeds TIER2_GROUP_THRESHOLD + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + c.min_grp_size = INT_MAX; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz < c.min_grp_size) c.min_grp_size = sz; + if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER0_64_GROUP_THRESHOLD) { + c.any_tier0_64 = true; + } + if (sz > TIER0_64_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { + c.any_tier2 = true; + } + if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + } + if (n_groups == 0) c.min_grp_size = 0; + + // use_tier0: Tier 0 kernel is worth running (at least one group small + // enough to benefit from the warp path). + c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); + // any_above_t0: at least one group needs a non-Tier-0 kernel. + c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); + // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). + c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// pair per warp. grid.y covers ceil(K/8) pair rows. +static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, + const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int sb_cols, int K, bool compute_tie_corr, + cudaStream_t stream) { + constexpr int WARPS_PER_BLOCK = 8; + dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); + ovo_warp_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_tier0_64( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + dim3 grid(sb_cols, K); + ovo_small64_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le); + CUDA_CHECK_LAST_ERROR(ovo_small64_sort_rank_kernel); +} + +static inline void launch_tier2_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = (size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_medium_unsorted_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + TIER2_GROUP_THRESHOLD); + CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..006002b9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,104 @@ +#pragma once + +/** Count nonzeros per column from CSR. One thread per row. */ +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. + */ +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + +/** + * Decide whether to use shared or global memory for OVR rank accumulators. + * Returns the smem size to request and sets use_gmem accordingly. + */ +static int query_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Fill sort values with row indices [0,1,...,n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. + */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..0f74a2c8 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,861 @@ +#pragma once + +/** + * Sparse-aware host-streaming CSC OVR pipeline. + * + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. + */ +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + int* d_sparse_indices; + int* d_seg_offsets; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + // In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + // Pin only the host input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = (int)(ptr_end - ptr_start); + + // H2D: transfer sparse data for this column range (native dtype) + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + // CUB sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSC streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware host-streaming CSR OVR pipeline. +// ============================================================================ + +/** + * Host CSR variant of the sparse OVR stream. + * + * The CSR input stays in host memory. We count columns once on the CPU, then + * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter + * on the GPU. This avoids both a full host->device sparse upload and any + * whole-matrix CSR->CSC conversion. + */ +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + RmmPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + // ---- Phase 0: CPU planning in native CSR order ---- + std::vector h_col_counts(n_cols, 0); + for (int row = 0; row < n_rows; row++) { + IndptrT rs = h_indptr[row]; + IndptrT re = h_indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)h_indices[p]; + if (c >= 0 && c < n_cols) h_col_counts[c]++; + } + } + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: allocate per-stream bounded work buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_sq_sums) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Pin the source CSR arrays as mapped memory. The scatter kernel reads + // only the requested column window from each row. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (total_nnz > 0) { + pin_data = + HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), + cudaHostRegisterMapped); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sq_sums; + double* sub_group_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.sub_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSR streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// ============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch for buffer sizing + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = ptr_end - ptr_start; + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + +/** + * Sparse-aware OVR streaming pipeline for GPU CSR data. + * + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. + */ +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmPool pool; + int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + // Upload all batch offsets to GPU in one shot (~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. + size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse CSR ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu new file mode 100644 index 00000000..19f1ef57 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -0,0 +1,292 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_fast_common.cuh" +#include "wilcoxon_sparse_kernels.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" + +using namespace nb::literals; + +template +void register_sparse_bindings(nb::module_& m) { + m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; + + m.def( + "ovr_sparse_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_streaming_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csr_streaming_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + + m.def( + "ovo_streaming_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, + int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, bool compute_sums, int sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ + n_groups_stats, compute_tie_corr, compute_sq_sums, \ + compute_nnz, compute_sums, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "compute_sums"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, + int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING +} + +NB_MODULE(_wilcoxon_sparse_cuda, m) { + REGISTER_GPU_BINDINGS(register_sparse_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh new file mode 100644 index 00000000..b0e40fdc --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -0,0 +1,651 @@ +#pragma once + +#include + +/** + * Fused rank-sum kernel: walk sorted data, compute per-group rank sums + * and tie correction without materializing a rank matrix. + * + * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting + * tie groups by adjacent comparison (sequential access, no binary search). + * Cross-boundary ties are resolved via binary search at chunk boundaries. + * + * When use_gmem is false, per-group accumulators live in shared memory + * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem + * is true, accumulators write directly to ``rank_sums`` in global memory, + * supporting an arbitrary number of groups. The caller must pre-zero + * ``rank_sums`` before launching in the gmem path. + * + * Shared memory layout: + * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) + * use_gmem=true: 32 doubles (warp buf only) + */ +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= n_cols) return; + + extern __shared__ double smem[]; + + double* grp_sums; + if (use_gmem) { + // Global memory path: write directly to output (must be pre-zeroed) + grp_sums = rank_sums + (size_t)col; // stride: n_cols + } else { + // Shared memory path: per-block accumulators + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + } + __syncthreads(); + } + + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; + + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; + + double local_tie_sum = 0.0; + + // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem + int acc_stride = use_gmem ? n_cols : 1; + + int i = my_start; + while (i < my_end) { + double val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0, hi = i; + while (lo < hi) { + int mid = lo + (hi - lo) / 2; + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Copy shared memory accumulators to global output (smem path only) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; + } + } + + if (compute_tie_corr) { + // Warp buf sits after accumulator array in shared memory. + // gmem path: warp buf starts at smem[0]. + // smem path: n_groups doubles, then warp buf. + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - val / denom) : 1.0; + } + } + } +} + +/** + * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. + * + * Sparse rank_genes_groups now rejects explicit negative sparse values before + * reaching CUDA, so after CUB sort each column segment is: + * [stored_zeros..., positives...] + * + * Implicit zeros (n_rows - nnz_stored) join stored zeros as the first tie + * block. The kernel ranks only stored positive values and adds each group's + * zero contribution analytically. + * + * Full sorted array (conceptual): + * [ALL_zeros (stored+implicit)..., positives...] + * + * Rank offsets: + * positive at stored pos i : full pos = i + n_implicit_zero + * zeros : avg rank = (total_zero + 1) / 2 + * + * Shared-memory layout (doubles): + * grp_sums[n_groups] rank-sum accumulators + * grp_nz_count[n_groups] nonzero-per-group counters + * warp_buf[32] tie-correction reduction scratch + * + * Grid: (sb_cols,) Block: (tpb,) + */ +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; + + const float* sv = sorted_vals + seg_start; + const int* si = sorted_row_idx + seg_start; + + extern __shared__ double smem[]; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; + + if (use_gmem) { + // Output rank_sums doubles as accumulator (pre-zeroed by caller). + grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + } + + // --- Find stored zero range: pos_start = first val > 0 --- + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + // Binary search: first index where sv[i] > 0.0 + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] <= 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int pos_start = sh_pos_start; + int n_stored_zero = pos_start; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = (total_zero > 0) ? (total_zero + 1.0) / 2.0 : 0.0; + + // Rank offset for positive stored values: + // full_pos(i) = i + n_implicit_zero for i >= pos_start + // So avg_rank for tie group [a,b) of positives: + // = n_implicit_zero + (a + b + 1) / 2 + int offset_pos = n_implicit_zero; + + // --- Count stored positive values per group --- + for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { + int grp = group_codes[si[i]]; + if (grp < n_groups) { + atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); + } + } + __syncthreads(); + + // --- Zero-rank contribution per group --- + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; + grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // --- Walk stored positives only and compute ranks --- + int n_pos = nnz_stored - pos_start; + int chunk = (n_pos + blockDim.x - 1) / blockDim.x; + int my_start = pos_start + threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = 0.0; + + int i = my_start; + while (i < my_end) { + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + // Binary search for first occurrence + int lo = pos_start, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < nnz_stored && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = nnz_stored - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + + double avg_rank = (double)offset_pos + + (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Zero tie group contribution (one thread only) + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + local_tie_sum += tz * tz * tz - tz; + } + + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 0 : 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; + } + } + } +} + +/** + * Decide whether the host cast+stats kernels can use per-block shared memory + * accumulators. Large group counts exceed the dynamic smem launch limit, so + * those cases fall back to direct global-memory atomics after zeroing the + * per-stream output buffers. + */ +static int wilcoxon_cast_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, + bool compute_nnz, bool& use_gmem) { + int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 0; +} + +/** + * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. + * + * Reads a sub-batch block in its native host dtype (InT = float or double), + * writes a float32 copy used as the sort input, and accumulates per-group + * sum, sum-of-squares and nonzero counts in float64. Stats are derived + * from the original-precision values so float64 host input keeps its + * precision while the sort still runs on float32 keys. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). + */ +template +__global__ void ovr_cast_and_accumulate_dense_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_dense_global_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +/** + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. + * + * Sub-batch CSC data is laid out contiguously: values for column c live + * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each + * stored value, read the native-dtype InT, write a float32 copy for the + * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. Implicit + * zeros contribute nothing to any of these stats. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles. + */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +template +static void launch_ovr_cast_and_accumulate_dense( + const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, + double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, + bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_dense_global_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); + } else { + ovr_cast_and_accumulate_dense_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, + bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, + bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + } +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..d399a301 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,6 +21,102 @@ ] +class _LazyRankGenesColumn: + def __init__( + self, + values: np.ndarray | None = None, + *, + var_names: np.ndarray | None = None, + gene_indices: np.ndarray | None = None, + dtype: str | np.dtype, + ) -> None: + self._values = values + self._var_names = var_names + self._gene_indices = gene_indices + self._dtype = np.dtype(dtype) + + def __len__(self) -> int: + if self._values is not None: + return int(self._values.shape[0]) + return int(self._gene_indices.shape[0]) + + def __getitem__(self, key): + if self._values is not None: + return self._values[key] + return self._var_names[self._gene_indices[key]] + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + def __array__(self, dtype=None, copy=None) -> np.ndarray: + if self._values is not None: + arr = np.asarray(self._values, dtype=self._dtype) + else: + arr = np.asarray(self._var_names[self._gene_indices], dtype=self._dtype) + if dtype is not None: + arr = np.asarray(arr, dtype=dtype) + if copy: + arr = arr.copy() + return arr + + +class _LazyRankGenesRecords(dict): + def __init__( + self, group_names: np.ndarray, columns: dict[str, object], dtype: str | np.dtype + ) -> None: + super().__init__(columns) + self._group_names = tuple(str(name) for name in group_names) + self._dtype = np.dtype([(name, np.dtype(dtype)) for name in self._group_names]) + + @property + def dtype(self) -> np.dtype: + return self._dtype + + def __getitem__(self, key): + if isinstance(key, str): + return super().__getitem__(key) + return np.asarray(self)[key] + + def __array__(self, dtype=None, copy=None) -> np.ndarray: + out = np.empty(len(next(iter(self.values()))) if self else 0, dtype=self._dtype) + for name in self._group_names: + out[name] = np.asarray(super().__getitem__(name)) + if dtype is not None: + out = np.asarray(out, dtype=dtype) + if copy: + out = out.copy() + return out + + def copy(self) -> np.ndarray: + return np.asarray(self).copy() + + +def _array_result_to_lazy_records( + arrays: dict[str, object], field: str, dtype: str | np.dtype +) -> _LazyRankGenesRecords: + group_names = arrays["group_names"] + values = arrays[field] + columns = { + str(group_name): _LazyRankGenesColumn(values[row], dtype=dtype) + for row, group_name in enumerate(group_names) + } + return _LazyRankGenesRecords(group_names, columns, dtype) + + +def _array_result_to_lazy_names(arrays: dict[str, object]) -> _LazyRankGenesRecords: + group_names = arrays["group_names"] + var_names = arrays["var_names"] + gene_indices = arrays["gene_indices"] + columns = { + str(group_name): _LazyRankGenesColumn( + var_names=var_names, gene_indices=gene_indices[row], dtype=object + ) + for row, group_name in enumerate(group_names) + } + return _LazyRankGenesRecords(group_names, columns, object) + + def rank_genes_groups( adata: AnnData, groupby: str, @@ -37,17 +133,21 @@ def rank_genes_groups( corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, use_continuity: bool = False, + return_u_values: bool = False, layer: str | None = None, chunk_size: int | None = None, pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ Rank genes for characterizing groups using GPU acceleration. - Expects logarithmized data. + Expects nonnegative expression data. Log1p/log-normalized data is expected + for biologically meaningful log fold changes; sparse inputs with explicit + negative values are rejected. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and @@ -101,6 +201,10 @@ def rank_genes_groups( z-scores. Subtracts 0.5 from ``|R - E[R]|`` before dividing by the standard deviation, matching :func:`scipy.stats.mannwhitneyu` default behavior. + return_u_values + For `'wilcoxon'`, store Mann-Whitney U statistics in `scores` instead + of z-scores. P-values are still computed from the z-score normal + approximation using the selected tie and continuity settings. layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size @@ -119,15 +223,22 @@ def rank_genes_groups( ``None`` (default) uses ``'auto'`` for in-memory arrays and ``'log1p'`` for Dask arrays (to avoid a costly data scan). ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. - ``'auto'`` computes the actual data range. Use this for z-scored - or unnormalized data. + ``'auto'`` computes the actual data range. Use this for nonnegative + expression data outside the fixed log1p range. + skip_empty_groups + Skip selected groups with fewer than two observations after filtering. + This is useful for perturbation workflows where a per-cell-type slice + keeps categories that are empty or singleton in that slice. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. Returns ------- - Updates `adata` with the following fields: + Updates `adata` with the following fields. Rank result fields are lazy + Scanpy-compatible record objects: group fields can be indexed like + structured arrays, while full structured arrays are materialized only when + requested through NumPy conversion or `.copy()`. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -135,7 +246,8 @@ def rank_genes_groups( `adata.uns['rank_genes_groups' | key_added]['scores']` Structured array to be indexed by group id storing the z-score underlying the computation of a p-value for each gene for each - group. Ordered according to scores. + group, or the Mann-Whitney U statistic when + `return_u_values=True`. Ordered according to scores. `adata.uns['rank_genes_groups' | key_added]['logfoldchanges']` Structured array to be indexed by group id storing the log2 fold change for each gene for each group. @@ -154,6 +266,13 @@ def rank_genes_groups( msg = "corr_method must be either 'benjamini-hochberg' or 'bonferroni'." raise ValueError(msg) + if "return_format" in kwds: + msg = ( + "return_format has been removed; rank_genes_groups always writes " + "lazy Scanpy-compatible results to adata.uns." + ) + raise TypeError(msg) + if method is None: method = "t-test" @@ -170,6 +289,10 @@ def rank_genes_groups( ) raise ValueError(msg) + if return_u_values and method != "wilcoxon": + msg = "return_u_values is only supported for method='wilcoxon'." + raise ValueError(msg) + if key_added is None: key_added = "rank_genes_groups" @@ -197,6 +320,7 @@ def rank_genes_groups( layer=layer, comp_pts=pts, pre_load=pre_load, + skip_empty_groups=skip_empty_groups, ) # Determine n_genes_user @@ -211,25 +335,14 @@ def rank_genes_groups( rankby_abs=rankby_abs, tie_correct=tie_correct, use_continuity=use_continuity, + return_u_values=return_u_values, chunk_size=chunk_size, n_bins=n_bins, bin_range=bin_range, **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - - dtypes = { - "names": "U50", - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - } - - adata.uns[key_added] = {} - adata.uns[key_added]["params"] = { + params = { "groupby": groupby, "reference": reference, "method": method, @@ -237,8 +350,28 @@ def rank_genes_groups( "layer": layer, "corr_method": corr_method, } + if method == "wilcoxon": + params["tie_correct"] = tie_correct + params["return_u_values"] = return_u_values + + arrays = test_obj.stats_arrays or {} + adata.uns[key_added] = {"params": params} + if arrays and len(arrays.get("group_names", ())) > 0: + adata.uns[key_added]["names"] = _array_result_to_lazy_names(arrays) + for col, dtype in { + "scores": "float32", + "logfoldchanges": "float32", + "pvals": "float64", + "pvals_adj": "float64", + }.items(): + if col in arrays: + values = arrays[col] + if hasattr(values, "dtype"): + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_lazy_records( + arrays, col, dtype + ) - # Store pts results if computed if test_obj.pts is not None: groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( @@ -249,14 +382,7 @@ def rank_genes_groups( test_obj.pts_rest.T, index=test_obj.var_names, columns=groups_names ) - if method == "wilcoxon": - adata.uns[key_added]["params"]["tie_correct"] = tie_correct - - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] - ) + return None if TYPE_CHECKING: @@ -285,7 +411,7 @@ def rank_genes_groups_logreg( layer: str | None = None, **kwds, ) -> None: - rank_genes_groups( + return rank_genes_groups( adata, groupby, groups=groups, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index c65bbf7c..acfbe2e2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,42 @@ from __future__ import annotations +import os +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import EPS, _check_sparse_nonnegative, _select_groups + +_FDR_BH_REVERSE_CUMMIN_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void fdr_bh_reverse_cummin(double* values, const int n_cols) { + const int row = blockIdx.x; + double running = 1.0; + double* row_values = values + static_cast(row) * n_cols; + for (int col = n_cols - 1; col >= 0; --col) { + double value = row_values[col]; + if (!(value == value)) { + value = 1.0; + } + if (value < running) { + running = value; + } + row_values[col] = running; + } +} +""", + "fdr_bh_reverse_cummin", +) +_RANK_SORT_MIN_ELEMENTS = 1_000_000 +_RANK_SORT_MAX_WORKERS = 64 if TYPE_CHECKING: from collections.abc import Iterable @@ -38,6 +62,7 @@ def __init__( layer: str | None = None, comp_pts: bool = False, pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: # Handle groups parameter if groups == "all" or groups is None: @@ -63,7 +88,10 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + reference=reference, + skip_empty_groups=skip_empty_groups, ) # Get data matrix @@ -91,6 +119,8 @@ def __init__( self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] + _check_sparse_nonnegative(self.X) + self.pre_load = pre_load self.ireference = None @@ -100,6 +130,7 @@ def __init__( # Set up expm1 function based on log base self.is_log1p = "log1p" in adata.uns base = adata.uns.get("log1p", {}).get("base") + self._log1p_base = base if base is not None: self.expm1_func = lambda x: np.expm1(x * np.log(base)) else: @@ -115,8 +146,14 @@ def __init__( self.pts_rest: np.ndarray | None = None self.stats: pd.DataFrame | None = None + self.stats_arrays: dict[str, object] | None = None + self._store_wilcoxon_gpu_result = False + self._wilcoxon_gpu_result: ( + tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None + ) = None self._compute_stats_in_chunks: bool = False self._ref_chunk_computed: set[int] = set() + self._score_dtype = np.dtype(np.float32) def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -190,16 +227,18 @@ def _basic_stats(self) -> None: # Compute rest statistics if reference='rest' if self.ireference is None: - n_rest = n.sum() - n - means_rest = (sums.sum(axis=0) - sums) / n_rest - rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2 + n_rest = cp.float64(self.X.shape[0]) - n + total_sums = result["sum"].sum(axis=0, keepdims=True) + total_sq_sums = result["sq_sum"].sum(axis=0, keepdims=True) + means_rest = (total_sums - sums) / n_rest + rest_ss = (total_sq_sums - sq_sums) - n_rest * means_rest**2 vars_rest = cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) self.means_rest = cp.asnumpy(means_rest) self.vars_rest = cp.asnumpy(vars_rest) if self.comp_pts: - total_count = (pts * n).sum(axis=0) + total_count = result["count_nonzero"].sum(axis=0, keepdims=True) self.pts_rest = cp.asnumpy((total_count - pts * n) / n_rest) else: self.pts_rest = None @@ -325,6 +364,7 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -334,6 +374,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) def wilcoxon_binned( @@ -375,6 +416,7 @@ def compute_statistics( chunk_size: int | None = None, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + return_u_values: bool = False, **kwds, ) -> None: """Compute statistics for all groups.""" @@ -385,17 +427,28 @@ def compute_statistics( }: self.X = X_to_GPU(self.X) + n_genes = self.X.shape[1] + if n_genes_user is None: + n_genes_user = n_genes + if method in {"t-test", "t-test_overestim_var"}: test_results = self.t_test(method) elif method == "wilcoxon": if isinstance(self.X, DaskArray): msg = "Wilcoxon test is not supported for Dask arrays. Please convert your data to CuPy arrays." raise ValueError(msg) - test_results = self.wilcoxon( - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - ) + self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) + self._wilcoxon_gpu_result = None + self._store_wilcoxon_gpu_result = n_genes_user is not None + try: + test_results = self.wilcoxon( + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + finally: + self._store_wilcoxon_gpu_result = False elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( tie_correct=tie_correct, @@ -409,58 +462,225 @@ def compute_statistics( else: assert_never(method) - n_genes = self.X.shape[1] + if not test_results and self._wilcoxon_gpu_result is None: + self.stats_arrays = { + "group_indices": np.empty(0, dtype=np.intp), + "group_names": np.empty(0, dtype=object), + "var_names": np.asarray(self.var_names), + "gene_indices": np.empty((0, n_genes_user), dtype=np.intp), + } + self.stats = None + return + + if self._wilcoxon_gpu_result is not None: + group_indices, scores_gpu, pvals_gpu, logfoldchanges_gpu = ( + self._wilcoxon_gpu_result + ) + try: + self._compute_statistics_gpu_arrays( + group_indices, + scores_gpu, + pvals_gpu, + logfoldchanges_gpu, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) + finally: + self._wilcoxon_gpu_result = None + return - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} + self._compute_statistics_arrays( + test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + @staticmethod + def _rank_indices_matrix(scores: np.ndarray, n_top: int) -> np.ndarray: + if n_top >= scores.shape[1]: + return _RankGenes._argsort_desc_matrix(scores) + partition = np.argpartition(scores, -n_top, axis=1)[:, -n_top:] + row_ids = np.arange(scores.shape[0])[:, None] + order = np.argsort(scores[row_ids, partition], axis=1)[:, ::-1] + return partition[row_ids, order] + + @staticmethod + def _argsort_desc_matrix(scores: np.ndarray) -> np.ndarray: + n_rows, n_cols = scores.shape + n_elements = n_rows * n_cols + n_workers = min(_RANK_SORT_MAX_WORKERS, os.cpu_count() or 1, n_rows) + if n_workers <= 1 or n_elements < _RANK_SORT_MIN_ELEMENTS: + return np.argsort(scores, axis=1)[:, ::-1] + + chunks = np.linspace(0, n_rows, n_workers + 1, dtype=np.intp) + indices = np.empty((n_rows, n_cols), dtype=np.intp) + + def sort_chunk(chunk_index: int) -> None: + start = int(chunks[chunk_index]) + stop = int(chunks[chunk_index + 1]) + if start < stop: + indices[start:stop] = np.argsort(scores[start:stop], axis=1)[:, ::-1] + + with ThreadPoolExecutor(max_workers=n_workers) as executor: + list(executor.map(sort_chunk, range(n_workers))) + return indices + + @staticmethod + def _fdr_bh_matrix(pvals: np.ndarray) -> np.ndarray: + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 + order = np.argsort(pvals_clean, axis=1) + sorted_p = np.take_along_axis(pvals_clean, order, axis=1) + n_tests = sorted_p.shape[1] + scale = n_tests / np.arange(1, n_tests + 1, dtype=np.float64) + corrected_sorted = sorted_p * scale + corrected_sorted = np.minimum.accumulate(corrected_sorted[:, ::-1], axis=1)[ + :, ::-1 + ] + corrected_sorted[corrected_sorted > 1.0] = 1.0 + corrected = np.empty_like(corrected_sorted) + np.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected + + @staticmethod + def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: + pvals_clean = cp.nan_to_num(pvals, nan=1.0) + order = cp.argsort(pvals_clean, axis=1) + corrected_sorted = cp.take_along_axis(pvals_clean, order, axis=1) + corrected_sorted *= corrected_sorted.shape[1] / cp.arange( + 1, corrected_sorted.shape[1] + 1, dtype=cp.float64 + ) + _FDR_BH_REVERSE_CUMMIN_KERNEL( + (corrected_sorted.shape[0],), + (1,), + (corrected_sorted, np.int32(corrected_sorted.shape[1])), + ) + corrected = cp.empty_like(corrected_sorted) + cp.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) + def _compute_statistics_arrays( + self, + test_results: list[tuple[int, NDArray, NDArray]], + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray([r[0] for r in test_results], dtype=np.intp) + scores = np.vstack([r[1] for r in test_results]) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": np.take_along_axis(scores, top_idx, axis=1).astype( + self._score_dtype, copy=False + ), + } + + if test_results[0][2] is not None: + pvals = np.vstack([r[2] for r in test_results]) + arrays["pvals"] = np.take_along_axis(pvals, top_idx, axis=1) + if corr_method == "benjamini-hochberg": + pvals_adj = self._fdr_bh_matrix(pvals) + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes, 1.0) else: - global_indices = slice(None) - - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] - - if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = np.take_along_axis(pvals_adj, top_idx, axis=1) + + if self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) + + self.stats_arrays = arrays + self.stats = None + + def _compute_statistics_gpu_arrays( + self, + group_indices: np.ndarray, + scores_gpu: cp.ndarray, + pvals_gpu: cp.ndarray, + logfoldchanges_gpu: cp.ndarray | None, + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray(group_indices, dtype=np.intp) + scores = cp.asnumpy(scores_gpu) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + top_idx_gpu = cp.asarray(top_idx) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": cp.asnumpy( + cp.take_along_axis(scores_gpu, top_idx_gpu, axis=1).astype( + self._score_dtype, copy=False ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] + ), + "pvals": cp.asnumpy(cp.take_along_axis(pvals_gpu, top_idx_gpu, axis=1)), + } + + if corr_method == "benjamini-hochberg": + pvals_adj_gpu = self._fdr_bh_matrix_gpu(pvals_gpu) + elif corr_method == "bonferroni": + pvals_adj_gpu = cp.minimum(pvals_gpu * n_genes, 1.0) + else: + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = cp.asnumpy( + cp.take_along_axis(pvals_adj_gpu, top_idx_gpu, axis=1) + ) + + if logfoldchanges_gpu is not None: + arrays["logfoldchanges"] = cp.asnumpy( + cp.take_along_axis(logfoldchanges_gpu, top_idx_gpu, axis=1).astype( + cp.float32, copy=False ) + ) + elif self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user is None: - self.stats.index = self.var_names - else: - self.stats = None + self.stats_arrays = arrays + self.stats = None diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..4ec37e40 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -18,9 +18,38 @@ MAX_THREADS_PER_BLOCK = 512 +def _check_sparse_nonnegative(X) -> None: + """Reject sparse matrices with explicit negative values. + + Sparse rank_genes_groups code treats missing entries as true expression + zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add + implicit zeros analytically, which is only valid when explicit sparse + values are nonnegative expression values. + """ + if sp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + msg = ( + "Sparse input contains negative values. rank_genes_groups " + "expects nonnegative expression values; use raw counts or " + "log1p/log-normalized expression, not scaled or centered data." + ) + raise ValueError(msg) + elif cpsp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + msg = ( + "Sparse input contains negative values. rank_genes_groups " + "expects nonnegative expression values; use raw counts or " + "log1p/log-normalized expression, not scaled or centered data." + ) + raise ValueError(msg) + + def _select_groups( labels: pd.Series, selected: list | None, + *, + reference: str = "rest", + skip_empty_groups: bool = False, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: """Build integer group codes from a categorical Series. @@ -51,6 +80,29 @@ def _select_groups( cat_order = {str(c): i for i, c in enumerate(all_categories)} selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + if skip_empty_groups: + counts = { + str(name): int(count) for name, count in labels.value_counts().items() + } + valid_selected = [group for group in selected if counts.get(str(group), 0) >= 2] + if reference != "rest": + ref_matches = [group for group in selected if str(group) == str(reference)] + if ref_matches: + ref_group = ref_matches[0] + if ref_group not in valid_selected: + msg = ( + f"reference = {reference} has fewer than two samples after " + "filtering and cannot be used for rank_genes_groups." + ) + raise ValueError(msg) + selected = valid_selected + if len(selected) == 0: + msg = ( + "No groups with at least two samples remain after applying " + "skip_empty_groups=True." + ) + raise ValueError(msg) + n_groups = len(selected) groups_order = np.array(selected) @@ -76,7 +128,7 @@ def _select_groups( if invalid_groups: msg = ( f"Could not calculate statistics for groups {', '.join(invalid_groups)} " - "since they only contain one sample." + "since they contain fewer than two samples." ) raise ValueError(msg) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..e20af614 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,14 +4,15 @@ from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc +from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import EPS, _choose_chunk_size, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray @@ -19,6 +20,14 @@ from ._core import _RankGenes MIN_GROUP_SIZE_WARNING = 25 +DEFAULT_WILCOXON_CHUNK_SIZE = 512 +OVO_SORT_GROUP_THRESHOLD = 512 +OVR_HOST_CSC_SUB_BATCH = 512 +OVR_HOST_CSR_SUB_BATCH = 2048 +OVR_DEVICE_CSC_SUB_BATCH = 2048 +OVR_DEVICE_CSR_SUB_BATCH = 2048 +OVO_HOST_SPARSE_SUB_BATCH = 256 +OVO_DEVICE_SPARSE_SUB_BATCH = 128 def _average_ranks( @@ -86,12 +95,307 @@ def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: return correction +def _extract_dense_rows_cols( + X, row_ids: np.ndarray, start: int, stop: int +) -> cp.ndarray: + """Extract a bounded row/column block as F-order CuPy dense memory.""" + if isinstance(X, np.ndarray): + return cp.asarray(X[row_ids, start:stop], order="F") + if isinstance(X, cp.ndarray): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows, start:stop]) + if isinstance(X, sp.spmatrix | sp.sparray): + return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") + if cpsp.issparse(X): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows][:, start:stop].toarray()) + raise TypeError(f"Unsupported matrix type: {type(X)}") + + +def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: + if requested is not None: + return _choose_chunk_size(requested) + return min(DEFAULT_WILCOXON_CHUNK_SIZE, max(1, n_genes)) + + +def _fill_ovo_chunk_stats( + rg: _RankGenes, + ref_block: cp.ndarray, + grp_block: cp.ndarray, + *, + offsets: np.ndarray, + test_group_indices: list[int], + start: int, + stop: int, + group_sizes: NDArray, +) -> None: + if not rg._compute_stats_in_chunks: + return + + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_mean = ref_block.mean(axis=0) + rg.means[ireference, start:stop] = cp.asnumpy(ref_mean) + if n_ref > 1: + rg.vars[ireference, start:stop] = cp.asnumpy(ref_block.var(axis=0, ddof=1)) + if rg.comp_pts: + ref_nnz = (ref_block != 0).sum(axis=0) + rg.pts[ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) + + for slot, group_index in enumerate(test_group_indices): + begin = int(offsets[slot]) + end = int(offsets[slot + 1]) + n_group = int(group_sizes[group_index]) + group_block = grp_block[begin:end] + group_mean = group_block.mean(axis=0) + rg.means[group_index, start:stop] = cp.asnumpy(group_mean) + if n_group > 1: + rg.vars[group_index, start:stop] = cp.asnumpy( + group_block.var(axis=0, ddof=1) + ) + if rg.comp_pts: + group_nnz = (group_block != 0).sum(axis=0) + rg.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) + + +def _fill_basic_stats_from_accumulators( + rg: _RankGenes, + group_sums: cp.ndarray, + group_sq_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: np.ndarray, + *, + n_cells: int, + compute_vars: bool, + total_sums: cp.ndarray | None = None, + total_sq_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, +) -> None: + n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] + means = group_sums / n + rg.means = cp.asnumpy(means) + if compute_vars: + group_ss = group_sq_sums - n * means**2 + rg.vars = cp.asnumpy(cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)) + else: + rg.vars = np.zeros_like(rg.means) + rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None + + n_rest = cp.float64(n_cells) - n + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums + rest_means = rest_sums / n_rest + rg.means_rest = cp.asnumpy(rest_means) + if compute_vars: + if total_sq_sums is None: + total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 + rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)) + else: + rg.vars_rest = np.zeros_like(rg.means_rest) + if rg.comp_pts: + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) + rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) + else: + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _fill_ovo_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sq_sums_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, + compute_vars: bool, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + means_slots = group_sums_slots / slot_sizes_dev + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + if compute_vars: + group_ss = group_sq_sums_slots - slot_sizes_dev * means_slots**2 + denom = cp.maximum(slot_sizes_dev - 1.0, 1.0) + rg.vars[slot_group_indices] = cp.asnumpy(cp.maximum(group_ss / denom, 0)) + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) + + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _ovo_logfoldchanges_from_sums( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, +) -> cp.ndarray: + n_test = int(test_sizes.shape[0]) + mean_group = group_sums_slots[:n_test] / test_sizes[:, None] + mean_ref = group_sums_slots[n_test][None, :] / cp.float64(n_ref) + if rg._log1p_base is not None: + scale = cp.float64(np.log(rg._log1p_base)) + group_expr = cp.expm1(mean_group * scale) + ref_expr = cp.expm1(mean_ref * scale) + else: + group_expr = cp.expm1(mean_group) + ref_expr = cp.expm1(mean_ref) + return cp.log2((group_expr + EPS) / (ref_expr + EPS)) + + +def _wilcoxon_scores( + rank_sums: cp.ndarray, + group_sizes: cp.ndarray, + z_scores: cp.ndarray, + *, + return_u_values: bool, +) -> cp.ndarray: + if not return_u_values: + return z_scores + n_group = group_sizes[:, None] + return rank_sums - n_group * (n_group + 1.0) / 2.0 + + +def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): + is_f64 = X.data.dtype == np.float64 + is_idx64 = support_idx64 and X.indices.dtype == np.int64 + is_i64 = X.indptr.dtype == np.int64 + suffix = "" + if is_f64: + suffix += "_f64" + if is_idx64: + suffix += "_idx64" + if is_i64: + suffix += "_i64" + fn = getattr(module, base_name + suffix) + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + return fn, data_arr, indices_arr + + +def _device_sparse_arrays_i32_f32(X): + if X.indptr.dtype != cp.int32: + max_indptr = int(cp.asnumpy(X.indptr[-1])) + if max_indptr > np.iinfo(np.int32).max: + return None + data = X.data.astype(cp.float32, copy=False) + indices = X.indices.astype(cp.int32, copy=False) + indptr = X.indptr.astype(cp.int32, copy=False) + return data, indices, indptr + + +def _column_totals_for_host_matrix( + X, *, compute_sq_sums: bool, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + n_cols = X.shape[1] + if isinstance(X, sp.spmatrix | sp.sparray): + data = np.asarray(X.data) + values = data.astype(np.float64, copy=False) + if X.format == "csc": + indptr = np.asarray(X.indptr) + counts = np.diff(indptr) + nonempty = counts > 0 + starts = indptr[:-1][nonempty] + sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sums[nonempty] = np.add.reduceat(values, starts) + sq_sums = None + if compute_sq_sums: + sq_sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sq_sums[nonempty] = np.add.reduceat(values * values, starts) + nnz = None + if compute_nnz: + nnz = np.zeros(n_cols, dtype=np.float64) + if starts.size: + nnz[nonempty] = np.add.reduceat( + (data != 0).astype(np.float64, copy=False), starts + ) + elif X.format == "csr": + indices = np.asarray(X.indices, dtype=np.intp) + sums = np.bincount(indices, weights=values, minlength=n_cols).astype( + np.float64, copy=False + ) + sq_sums = ( + np.bincount(indices, weights=values * values, minlength=n_cols).astype( + np.float64, copy=False + ) + if compute_sq_sums + else None + ) + nnz = ( + np.bincount( + indices, + weights=(data != 0).astype(np.float64, copy=False), + minlength=n_cols, + ).astype(np.float64, copy=False) + if compute_nnz + else None + ) + else: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + else: + raise TypeError(f"Unsupported host matrix type: {type(X)}") + + total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) + total_sq_sums = ( + cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) + if sq_sums is not None + else None + ) + total_nnz = ( + cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) + if nnz is not None + else None + ) + return total_sums, total_sq_sums, total_nnz + + +def _host_ovr_totals_if_needed( + X, + group_codes: np.ndarray, + n_groups: int, + *, + compute_sq_sums: bool, + compute_nnz: bool, +) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: + if not np.any(group_codes == n_groups): + return None, None, None + return _column_totals_for_host_matrix( + X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz + ) + + def wilcoxon( rg: _RankGenes, *, tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" # Compute basic stats - uses Aggregate if on GPU, else defers to chunks @@ -110,6 +414,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( @@ -121,6 +426,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) @@ -134,6 +440,7 @@ def _wilcoxon_vs_rest( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells.""" n_groups = len(rg.groups_order) @@ -149,26 +456,203 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + + group_codes = rg.group_codes.astype(np.int32, copy=False) + group_sizes_np = group_sizes.astype(np.float64, copy=False) + group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_vars = False + compute_nnz = rg.comp_pts + + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups, n_total_genes) if compute_vars else (1, 1), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) + + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csc_host", csc, support_idx64=False + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, + ) + else: + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() + csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, + ) + + if rg._compute_stats_in_chunks: + total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + X, + group_codes, + n_groups, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + ) + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes_np, + n_cells=n_cells, + compute_vars=compute_vars, + total_sums=total_sums, + total_sq_sums=total_sq_sums, + total_nnz=total_nnz, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_arrays = _device_sparse_arrays_i32_f32(X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + if cpsp.isspmatrix_csc(X): + _wcs.ovr_sparse_csc_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, + ) + else: + sparse_X = X + if not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + data, indices, indptr = _device_sparse_arrays_i32_f32(sparse_X) + _wcs.ovr_sparse_csr_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = ( + tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + ) + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_matrix = None + if rg._compute_stats_in_chunks: + codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) + group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) + valid_idx = cp.where(codes_gpu < n_groups)[0] + group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - chunk_width = _choose_chunk_size(chunk_size) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) # Accumulate results per group all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) @@ -185,14 +669,28 @@ def _wilcoxon_vs_rest( n_cells=n_cells, ) - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks + block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) + sorter = cp.asfortranarray(cp.argsort(block_f32, axis=0).astype(cp.int32)) + sorted_vals = cp.asfortranarray(cp.take_along_axis(block_f32, sorter, axis=0)) + n_cols = stop - start + rank_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + tie_corr = ( + cp.empty(n_cols, dtype=cp.float64) + if tie_correct + else cp.ones(n_cols, dtype=cp.float64) + ) + _wc.ovr_rank_dense( + sorted_vals, + sorter, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_cols, + n_groups=n_groups, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] variance *= (n_cells + 1) / 12.0 @@ -203,12 +701,15 @@ def _wilcoxon_vs_rest( z = diff / std cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ) - z_host = z.get() + scores_host = scores.get() p_host = p_values.get() for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) + all_scores[idx].append(scores_host[idx]) all_pvals[idx].append(p_host[idx]) # Collect results per group @@ -227,98 +728,379 @@ def _wilcoxon_with_reference( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" + """Wilcoxon test: all selected groups vs a specific reference group.""" codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference - - results: list[tuple[int, NDArray, NDArray]] = [] + n_groups = len(rg.groups_order) + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: - continue + test_group_indices = [i for i in range(n_groups) if i != ireference] + if not test_group_indices: + return [] - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + offsets = [0] + row_id_parts = [] + small_groups = [] + for group_index in test_group_indices: + group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) + row_id_parts.append(group_rows) + offsets.append(offsets[-1] + int(group_rows.size)) + if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING: + small_groups.append(str(rg.groups_order[group_index])) - # Warn for small groups - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + if n_ref <= MIN_GROUP_SIZE_WARNING or small_groups: + parts = [] + if small_groups: + parts.append( + f"{len(small_groups)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_groups[:5])}" + f"{'...' if len(small_groups) > 5 else ''})" ) + if n_ref <= MIN_GROUP_SIZE_WARNING: + parts.append(f"reference has size {n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) - # Combined mask: group + reference - mask_obs = codes == group_index - mask_combined = mask_obs | mask_ref - - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] + all_grp_row_ids = ( + np.concatenate(row_id_parts).astype(np.int32, copy=False) + if row_id_parts + else np.empty(0, dtype=np.int32) + ) + offsets_np = np.asarray(offsets, dtype=np.int32) + offsets_gpu = cp.asarray(offsets_np) + n_all_grp = int(all_grp_row_ids.size) + n_test = len(test_group_indices) + max_test_size = int(np.diff(offsets_np).max(initial=0)) + use_presorted_groups = max_test_size > OVO_SORT_GROUP_THRESHOLD + test_sizes = cp.asarray( + group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( + np.float64, copy=False + ) + ) - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." ) - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) - - chunk_width = _choose_chunk_size(chunk_size) - - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + n_groups_stats = n_test + 1 + compute_vars = False + compute_sums = rg._compute_stats_in_chunks + compute_nnz = rg.comp_pts + group_sums = cp.empty( + (n_groups_stats, n_total_genes) + if (compute_sums or X.format == "csc") + else (1,), + dtype=cp.float64, + ) + group_sq_sums = cp.empty( + (n_groups_stats, n_total_genes) if compute_vars else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) + stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) + test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) + stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) + stats_code_lookup[ireference] = n_test + stats_codes = stats_code_lookup[codes] - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csc_host", csc, support_idx64=True + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + ref_row_map, + grp_row_map, + offsets_np, + stats_codes, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, + ) + else: + csr = X + csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + ref_row_ids.astype(np.int32, copy=False), + all_grp_row_ids.astype(np.int32, copy=False), + offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_full_rows=X.shape[0], n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_test=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + compute_sums=compute_sums, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + logfoldchanges_gpu = None + if rg._compute_stats_in_chunks: + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( + rg, + group_sums, + test_sizes, + n_ref, + ) + rg._compute_stats_in_chunks = False else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + compute_vars=compute_vars, + ) - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + logfoldchanges_gpu, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_X = X + if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + sparse_arrays = _device_sparse_arrays_i32_f32(sparse_X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + + if cpsp.isspmatrix_csc(sparse_X): + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_device( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + _wcs.ovo_streaming_csr_device( + data, + indices, + indptr, + cp.asarray(ref_row_ids, dtype=cp.int32), + cp.asarray(all_grp_row_ids, dtype=cp.int32), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr diff = rank_sums - expected if use_continuity: diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std + z = diff / cp.sqrt(variance) cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + None, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + + scores_host = np.empty((n_test, n_total_genes), dtype=np.float64) + pvals_host = np.empty((n_test, n_total_genes), dtype=np.float64) + + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + n_cols = stop - start + + ref_block = _extract_dense_rows_cols(X, ref_row_ids, start, stop) + grp_block = _extract_dense_rows_cols(X, all_grp_row_ids, start, stop) + + _fill_ovo_chunk_stats( + rg, + ref_block, + grp_block, + offsets=offsets_np, + test_group_indices=test_group_indices, + start=start, + stop=stop, + group_sizes=group_sizes, + ) + + ref_sorted = cp.asfortranarray(cp.sort(ref_block.astype(cp.float32), axis=0)) + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) + tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) + + if use_presorted_groups: + grp_rank_input = cp.empty_like(grp_f32) + for slot in range(n_test): + begin = int(offsets_np[slot]) + end = int(offsets_np[slot + 1]) + grp_rank_input[begin:end] = cp.sort(grp_f32[begin:end], axis=0) + grp_rank_input = cp.asfortranarray(grp_rank_input) + _wc.ovo_rank_presorted( + ref_sorted, + grp_rank_input, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) + else: + _wc.ovo_rank_dense( + ref_sorted, + grp_f32, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr + std = cp.sqrt(variance) + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / std + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) - results.append((group_index, scores, pvals)) + scores_host[:, start:stop] = scores.get() + pvals_host[:, start:stop] = p_values.get() - return results + return [ + (group_index, scores_host[slot], pvals_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index fa4bbccf..70d049af 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -102,7 +102,7 @@ def wilcoxon_binned( ``'log1p'`` uses a fixed [0, 15] range suitable for log1p-normalized data. ``'auto'`` computes the actual (min, max) of the data. Use this - for z-scored or unnormalized data. + for nonnegative expression data outside the fixed log1p range. """ if not rg.is_log1p: warnings.warn( @@ -119,20 +119,6 @@ def wilcoxon_binned( if n_bins is None: n_bins = _DASK_N_BINS if isinstance(X, DaskArray) else _DEFAULT_N_BINS - # Sparse kernels assume non-negative data (pre-fill+correct pattern). - # Dense kernel handles any range. - # NOTE: Dask sparse is not validated here because checking .data.min() - # would require materializing all blocks. The sparse histogram kernels - # will silently produce incorrect results for negative Dask sparse data. - if not isinstance(X, DaskArray) and cpsp.issparse(X) and X.nnz > 0: - if float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. The sparse histogram " - "kernels assume non-negative data. Convert to dense or use " - "bin_range='auto' with a dense array." - ) - raise ValueError(msg) - n_groups = len(rg.groups_order) n_cells, n_genes = X.shape group_sizes = rg.group_sizes diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index 8fe93ae7..24a40721 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -20,6 +20,7 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: + adata_gpu.X = np.abs(adata_gpu.X).astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -52,12 +53,19 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-6 if sparse else 1e-13 + if sparse and field in {"scores", "logfoldchanges"}: + atol = 1e-6 + elif sparse: + atol = 1e-12 + else: + atol = 1e-15 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) np.testing.assert_allclose( - gpu_values, cpu_values, rtol=1e-13, atol=1e-15, equal_nan=True + gpu_values, cpu_values, rtol=rtol, atol=atol, equal_nan=True ) params = gpu_result["params"] diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 0c6844da..87030dfb 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,6 +1,7 @@ from __future__ import annotations import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest @@ -11,6 +12,177 @@ import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + +def _make_nonnegative(adata): + adata.X = np.abs(np.asarray(adata.X)).astype(np.float32) + return adata + + +@pytest.mark.parametrize( + "method", + ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], +) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(ValueError, match="Sparse input contains negative values"): + rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + + +def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): + np.random.seed(42) + adata_lazy = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_lazy) + adata_lazy.obs["blobs"] = adata_lazy.obs["blobs"].astype("category") + adata_lazy.X = sp.csr_matrix(adata_lazy.X) + adata_cpu = adata_lazy.copy() + adata_cpu.X = adata_cpu.X.toarray() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "reference": "1", + "use_raw": False, + "tie_correct": True, + "n_genes": 4, + } + rsc.tl.rank_genes_groups(adata_lazy, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + lazy_result = adata_lazy.uns["rank_genes_groups"] + assert lazy_result["names"].dtype.names == ("0", "2") + assert tuple(lazy_result["names"][0]) == tuple( + adata_cpu.uns["rank_genes_groups"]["names"][0] + ) + np.testing.assert_array_equal( + lazy_result["names"].copy(), + np.asarray(lazy_result["names"]), + ) + + lazy_df = sc.get.rank_genes_groups_df(adata_lazy, group=None) + scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) + pd.testing.assert_frame_equal(lazy_df, scanpy_df) + + +def test_rank_genes_groups_return_format_removed(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(TypeError, match="return_format has been removed"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + return_format="arrays", + ) + + +@pytest.mark.parametrize("reference", ["rest", "b"]) +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_csr"]) +def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): + X = np.array( + [ + [5.0, 0.0, 1.0, 2.0], + [4.0, 0.0, 1.0, 2.0], + [1.0, 3.0, 2.0, 2.0], + [0.0, 2.0, 2.0, 2.0], + [2.0, 1.0, 0.0, 3.0], + [3.0, 1.0, 0.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "b", "b", "c", "c"]) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference=reference, + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + result = adata.uns["rank_genes_groups"] + assert result["params"]["return_u_values"] is True + assert result["scores"].dtype["a"] == np.dtype("float64") + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + mask_group = labels == "a" + mask_ref = labels != "a" if reference == "rest" else labels == reference + expected = np.array( + [ + mannwhitneyu( + X[mask_group, gene], + X[mask_ref, gene], + alternative="two-sided", + ).statistic + for gene in range(X.shape[1]) + ], + dtype=np.float64, + ) + + gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} + expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) + np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) + + +def test_rank_genes_groups_return_u_values_requires_wilcoxon(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="only supported for method='wilcoxon'"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="t-test", + use_raw=False, + return_u_values=True, + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -21,6 +193,7 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: + _make_nonnegative(adata_gpu) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -55,11 +228,13 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - np.testing.assert_allclose(gpu_values, cpu_values, rtol=1e-13, atol=1e-15) + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] assert params["use_raw"] is False @@ -148,6 +323,230 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["valid"] * 10 + ["singleton"], + categories=["ref", "valid", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + n_genes=3, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert result["names"].dtype.names == ("valid",) + assert result["scores"].dtype.names == ("valid",) + + +def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["singleton"], + categories=["ref", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "names" not in result + assert result["params"]["reference"] == "ref" + + +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): + rng = np.random.default_rng(42) + dense = rng.poisson(1.0, size=(80, 12)).astype(np.float32) + dense[rng.random(dense.shape) < 0.55] = 0 + sorted_csr = sp.csr_matrix(dense) + unsorted_csr = sorted_csr.copy() + for row in range(unsorted_csr.shape[0]): + start, stop = unsorted_csr.indptr[row : row + 2] + order = np.arange(stop - start)[::-1] + unsorted_csr.indices[start:stop] = unsorted_csr.indices[start:stop][order] + unsorted_csr.data[start:stop] = unsorted_csr.data[start:stop][order] + unsorted_csr.has_sorted_indices = False + + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["ref"] * 20 + ["a"] * 20 + ["b"] * 20 + ["c"] * 20, + categories=["ref", "a", "b", "c"], + ) + } + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(dense.shape[1])]) + sorted_adata = sc.AnnData(X=sorted_csr, obs=obs.copy(), var=var.copy()) + unsorted_adata = sc.AnnData(X=unsorted_csr, obs=obs.copy(), var=var.copy()) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "ref", + "use_raw": False, + "tie_correct": True, + "n_genes": dense.shape[1], + } + rsc.tl.rank_genes_groups(sorted_adata, **kw) + rsc.tl.rank_genes_groups(unsorted_adata, **kw) + + sorted_result = sorted_adata.uns["rank_genes_groups"] + unsorted_result = unsorted_adata.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in sorted_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(unsorted_result[field][group], dtype=float), + np.asarray(sorted_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + "numpy_dense", + "scipy_csr", + "scipy_csc", + "cupy_dense", + "cupy_csr", + "cupy_csc", + ], +) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 5, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + @pytest.mark.parametrize( "reference_before,reference_after", [("rest", "rest"), ("1", "One")], From 4094e6be4a54dc72b52ca1adb1e12267ac3e5c1b Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 17:07:29 +0200 Subject: [PATCH 02/36] add rmm --- CMakeLists.txt | 70 +++++++++++++++++++ notebooks | 2 +- pyproject.toml | 18 ++++- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 52 +++++++------- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 4 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 4 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 46 ++++++------ .../_cuda/wilcoxon/wilcoxon_rmm.cu | 20 ++++++ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 24 +++++-- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 5 +- .../tools/_rank_genes_groups/_wilcoxon.py | 2 +- 11 files changed, 181 insertions(+), 66 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 85d33e91..67d8090c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,74 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + set(RSC_RMM_HINTS) + set(RSC_RAPIDS_CMAKE_PREFIXES) + set(RSC_CCCL_HINTS) + set(RSC_RAPIDS_LOGGER_HINTS) + set(RSC_NVTX3_HINTS) + macro(_rsc_collect_rapids_python_prefix _rsc_prefix) + if (NOT "${_rsc_prefix}" STREQUAL "") + file(GLOB _rsc_rmm_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/rmm") + file(GLOB _rsc_rapids_prefixes + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64" + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids" + "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib" + ) + file(GLOB _rsc_cccl_dirs + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids/cmake/cccl" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib/cmake/cccl" + ) + file(GLOB _rsc_rapids_logger_dirs "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_nvtx3_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_nvtx3_dirs}) + endif() + endmacro() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import importlib.util, pathlib; spec = importlib.util.find_spec('librmm'); print(pathlib.Path(spec.origin).parent / 'lib64' / 'cmake' / 'rmm' if spec else '')" + OUTPUT_VARIABLE RSC_PYTHON_RMM_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") + list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") + endif() + foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") + _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") + endforeach() + foreach(_rsc_env_prefix IN ITEMS "$ENV{CONDA_PREFIX}" "$ENV{VIRTUAL_ENV}") + _rsc_collect_rapids_python_prefix("${_rsc_env_prefix}") + endforeach() + string(REPLACE ":" ";" _rsc_path_entries "$ENV{PATH}") + foreach(_rsc_path_entry IN LISTS _rsc_path_entries) + get_filename_component(_rsc_path_prefix "${_rsc_path_entry}/.." ABSOLUTE) + _rsc_collect_rapids_python_prefix("${_rsc_path_prefix}") + endforeach() + if (RSC_RAPIDS_CMAKE_PREFIXES) + list(APPEND CMAKE_PREFIX_PATH ${RSC_RAPIDS_CMAKE_PREFIXES}) + if (RSC_CCCL_HINTS) + list(GET RSC_CCCL_HINTS 0 _rsc_cccl_dir) + set(CCCL_DIR "${_rsc_cccl_dir}" CACHE PATH "Path to CCCL package config" FORCE) + endif() + if (RSC_RAPIDS_LOGGER_HINTS) + list(GET RSC_RAPIDS_LOGGER_HINTS 0 _rsc_rapids_logger_dir) + set(rapids_logger_DIR "${_rsc_rapids_logger_dir}" CACHE PATH "Path to rapids_logger package config" FORCE) + endif() + if (RSC_NVTX3_HINTS) + list(GET RSC_NVTX3_HINTS 0 _rsc_nvtx3_dir) + set(nvtx3_DIR "${_rsc_nvtx3_dir}" CACHE PATH "Path to nvtx3 package config" FORCE) + endif() + endif() + if (RSC_RMM_HINTS) + find_package(rmm CONFIG REQUIRED HINTS ${RSC_RMM_HINTS}) + else() + find_package(rmm CONFIG REQUIRED) + endif() + message(STATUS "Using RMM for CUDA extension scratch allocations") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -86,6 +154,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) + target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/notebooks b/notebooks index 4cdaa44f..e5c97b34 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d +Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf diff --git a/pyproject.toml b/pyproject.toml index c38e1d00..dc69471a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "cudf-cu13>=25.10", + "cuml-cu13>=25.10", + "cugraph-cu13>=25.10", + "cuvs-cu13>=25.10", + "rmm-cu13>=25.10", +] +rapids-cu12 = [ + "cupy-cuda12x", + "cudf-cu12>=25.10", + "cuml-cu12>=25.10", + "cugraph-cu12>=25.10", + "cuvs-cu12>=25.10", + "rmm-cu12>=25.10", +] doc = [ "sphinx>=4.5.0", diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index dd50d2cb..2d5b3f2c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -9,6 +9,9 @@ #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR +void* wilcoxon_rmm_allocate(size_t bytes); +void wilcoxon_rmm_deallocate(void* ptr, size_t bytes); + constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; constexpr int N_STREAMS = 4; @@ -93,50 +96,45 @@ struct HostRegisterGuard { }; // --------------------------------------------------------------------------- -// Small allocation pool for temporary CUDA buffers. The previous PR used RMM -// here, but these sparse Wilcoxon kernels only need scoped scratch memory; -// using cudaMalloc keeps this module independent of an extra build-time -// dependency. +// Small allocation pool for temporary CUDA buffers. Uses the current RMM device +// resource so scratch participates in the same pool as CuPy/RAPIDS allocations. // --------------------------------------------------------------------------- -struct RmmPool { - std::vector bufs; - - ~RmmPool() { - for (void* ptr : bufs) { - if (ptr) cudaFree(ptr); +struct RmmScratchPool { + struct Allocation { + void* ptr = nullptr; + size_t bytes = 0; + }; + std::vector bufs; + + ~RmmScratchPool() { + for (Allocation alloc : bufs) { + if (!alloc.ptr) continue; + wilcoxon_rmm_deallocate(alloc.ptr, alloc.bytes); } } template T* alloc(size_t count) { if (count == 0) count = 1; - void* ptr = nullptr; - cudaError_t err = cudaMalloc(&ptr, count * sizeof(T)); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("cudaMalloc failed in Wilcoxon scratch pool: ") + - cudaGetErrorString(err)); - } - bufs.push_back(ptr); + size_t bytes = count * sizeof(T); + void* ptr = wilcoxon_rmm_allocate(bytes); + bufs.push_back({ptr, bytes}); return static_cast(ptr); } }; struct ScopedCudaBuffer { void* ptr = nullptr; + size_t bytes = 0; - explicit ScopedCudaBuffer(size_t bytes) { - if (bytes == 0) bytes = 1; - cudaError_t err = cudaMalloc(&ptr, bytes); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("cudaMalloc failed in Wilcoxon scoped buffer: ") + - cudaGetErrorString(err)); - } + explicit ScopedCudaBuffer(size_t requested_bytes) { + bytes = requested_bytes == 0 ? 1 : requested_bytes; + ptr = wilcoxon_rmm_allocate(bytes); } ~ScopedCudaBuffer() { - if (ptr) cudaFree(ptr); + if (!ptr) return; + wilcoxon_rmm_deallocate(ptr, bytes); } void* data() { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 7ad20b01..b195bee0 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -59,7 +59,7 @@ static void ovo_streaming_csr_impl( } if (ref_cache_cols < 1) ref_cache_cols = 1; - RmmPool pool; + RmmScratchPool pool; size_t cub_temp_bytes = 0; if (needs_tier3) { @@ -340,7 +340,7 @@ static void ovo_streaming_csc_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int* d_sort_group_ids = nullptr; if (needs_tier3) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index feb86e57..11827b0a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -73,7 +73,7 @@ static void ovo_streaming_csc_host_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); @@ -470,7 +470,7 @@ static void ovo_streaming_csr_host_impl( size_t max_sub_items = (size_t)max_pack_items; if (max_pack_rows == 0) return; - RmmPool pool; + RmmScratchPool pool; // Zero stats outputs. if (compute_sums) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 0f74a2c8..6eae2a28 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -7,9 +7,9 @@ * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. */ -template +template static void ovr_sparse_csc_host_streaming_impl( - const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, int n_rows, int n_cols, int n_groups, @@ -33,7 +33,7 @@ static void ovr_sparse_csc_host_streaming_impl( size_t cub_temp_bytes = 0; if (max_nnz > 0) { auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); @@ -42,16 +42,16 @@ static void ovr_sparse_csc_host_streaming_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); struct StreamBuf { InT* d_sparse_data_orig; float* d_sparse_data_f32; - int* d_sparse_indices; + IndexT* d_sparse_indices; int* d_seg_offsets; float* keys_out; - int* vals_out; + IndexT* vals_out; uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; @@ -64,10 +64,10 @@ static void ovr_sparse_csc_host_streaming_impl( for (int s = 0; s < n_streams; s++) { bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].keys_out = pool.alloc(max_nnz); - bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); @@ -128,8 +128,8 @@ static void ovr_sparse_csc_host_streaming_impl( size_t total_nnz = (size_t)h_indptr[n_cols]; HostRegisterGuard _pin_data(const_cast(h_data), total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(int)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); cudaDeviceSynchronize(); @@ -151,7 +151,7 @@ static void ovr_sparse_csc_host_streaming_impl( (size_t)batch_nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - (size_t)batch_nnz * sizeof(int), + (size_t)batch_nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); } @@ -161,7 +161,7 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyDeviceToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 - launch_ovr_cast_and_accumulate_sparse( + launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, @@ -187,10 +187,12 @@ static void ovr_sparse_csc_host_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, - d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + rank_sums_sparse_ovr_kernel + <<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); // D2D: scatter sub-batch results into caller's GPU buffers @@ -257,7 +259,7 @@ static void ovr_sparse_csr_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - RmmPool pool; + RmmScratchPool pool; size_t total_nnz = (size_t)h_indptr[n_rows]; // ---- Phase 0: CPU planning in native CSR order ---- @@ -466,7 +468,7 @@ static void ovr_sparse_csr_host_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, @@ -558,7 +560,7 @@ static void ovr_sparse_csc_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - RmmPool pool; + RmmScratchPool pool; struct StreamBuf { float* keys_out; int* vals_out; @@ -626,7 +628,7 @@ static void ovr_sparse_csc_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); @@ -681,7 +683,7 @@ static void ovr_sparse_csr_streaming_impl( if (n_rows == 0 || n_cols == 0) return; // ---- Phase 0: Planning — count nnz per column via histogram ---- - RmmPool pool; + RmmScratchPool pool; int* d_col_counts = pool.alloc(n_cols); cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); { @@ -829,7 +831,7 @@ static void ovr_sparse_csr_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu new file mode 100644 index 00000000..26e37f42 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -0,0 +1,20 @@ +#include +#include +#include + +#include +#include + +void* wilcoxon_rmm_allocate(size_t bytes) { + try { + return rmm::mr::get_current_device_resource()->allocate_sync(bytes); + } catch (std::exception const& e) { + throw std::runtime_error( + std::string("RMM allocation failed in Wilcoxon scratch: ") + + e.what()); + } +} + +void wilcoxon_rmm_deallocate(void* ptr, size_t bytes) { + rmm::mr::get_current_device_resource()->deallocate_sync(ptr, bytes); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 19f1ef57..4316d284 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -61,10 +61,10 @@ void register_sparse_bindings(nb::module_& m) { "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); -#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ m.def( \ NAME, \ - [](host_array h_data, host_array h_indices, \ + [](host_array h_data, host_array h_indices, \ host_array h_indptr, \ host_array h_group_codes, \ host_array h_group_sizes, \ @@ -75,7 +75,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_group_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ bool compute_nnz, int sub_batch_cols) { \ - ovr_sparse_csc_host_streaming_impl( \ + ovr_sparse_csc_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ @@ -90,11 +90,21 @@ void register_sparse_bindings(nb::module_& m) { "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ "sub_batch_cols"_a = SUB_BATCH_COLS) - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64_i64", double, + int64_t, int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING #define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index b0e40fdc..603c1c96 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -172,9 +172,10 @@ __global__ void rank_sums_from_sorted_kernel( * * Grid: (sb_cols,) Block: (tpb,) */ +template __global__ void rank_sums_sparse_ovr_kernel( const float* __restrict__ sorted_vals, - const int* __restrict__ sorted_row_idx, + const IndexT* __restrict__ sorted_row_idx, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, const double* __restrict__ group_sizes, double* __restrict__ rank_sums, double* __restrict__ tie_corr, @@ -188,7 +189,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int nnz_stored = seg_end - seg_start; const float* sv = sorted_vals + seg_start; - const int* si = sorted_row_idx + seg_start; + const IndexT* si = sorted_row_idx + seg_start; extern __shared__ double smem[]; double* grp_sums; diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index e20af614..90c54eb2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -489,7 +489,7 @@ def _wilcoxon_vs_rest( csc = csc.copy() csc.sort_indices() csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csc_host", csc, support_idx64=False + _wcs, "ovr_sparse_csc_host", csc, support_idx64=True ) csc_host_fn( data_arr, From 9c391ed369c347ad3ae4ea0fb4c8a2c169b33829 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 17:42:23 +0200 Subject: [PATCH 03/36] update publish and cmake --- .github/workflows/publish.yml | 52 ++++- CMakeLists.txt | 20 ++ pyproject.toml | 9 +- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 141 ------------- .../_cuda/wilcoxon/wilcoxon.cu | 48 ----- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 9 - .../wilcoxon/wilcoxon_sparse_kernels.cuh | 145 -------------- .../tools/_rank_genes_groups/_wilcoxon.py | 65 ------ tests/test_rank_genes_groups_wilcoxon.py | 187 +----------------- 9 files changed, 74 insertions(+), 602 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3f2e4447..4ca5d522 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -69,16 +69,47 @@ jobs: path = pathlib.Path("pyproject.toml") text = path.read_text() + def remove_toml_array(text, key): + lines = text.splitlines(keepends=True) + out = [] + i = 0 + while i < len(lines): + if lines[i].startswith(f"{key} = ["): + depth = lines[i].count("[") - lines[i].count("]") + i += 1 + while i < len(lines) and depth > 0: + depth += lines[i].count("[") - lines[i].count("]") + i += 1 + continue + out.append(lines[i]) + i += 1 + return "".join(out) + # Rename package text = text.replace( 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) # Rename matching extra to "rapids", remove the other - text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + text = text.replace(f'rapids-cu{cuda} = [', 'rapids = [') + text = remove_toml_array(text, f"rapids-cu{other}") + + # librmm is needed at build time because CMake links the CUDA + # extension against librmm. Add the matching wheel to the isolated + # PEP 517 build requirements after selecting the CUDA package variant. + for dep in ( + f' "librmm-cu{other}>=25.10",\n', + f' "rmm-cu{other}>=25.10",\n', + ): + text = text.replace(dep, "") + rmm_build_req = f' "librmm-cu{cuda}>=25.10",\n' + build_system_text = text.split("[project]", 1)[0] + if f'"librmm-cu{cuda}>=25.10"' not in build_system_text: + text = text.replace( + ']\nbuild-backend = "scikit_build_core.build"', + f'{rmm_build_req}]\nbuild-backend = "scikit_build_core.build"', + 1, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( @@ -96,6 +127,7 @@ jobs: - name: Sanity check pyproject.toml run: | + python3 -c "import tomllib; tomllib.load(open('pyproject.toml', 'rb'))" grep -E "name|rapids|CUDA_ARCH" pyproject.toml - name: Build CUDA manylinux image @@ -117,9 +149,19 @@ jobs: CIBW_BEFORE_BUILD: > python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} && + RMM_ROOT=$(python -c "import librmm; print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger; print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig && + python -c "import librmm; print(librmm.__path__[0])" > /tmp/.librmm_dir && + echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 67d8090c..85fcfc2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,26 @@ if (RSC_BUILD_EXTENSIONS) if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") endif() + if(EXISTS "/tmp/.librmm_dir") + file(READ "/tmp/.librmm_dir" _rsc_librmm_marker) + string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") + file(GLOB _rsc_marker_rapids_prefixes + "${_rsc_librmm_marker}/lib64" + "${_rsc_librmm_marker}/lib64/rapids" + "${_rsc_librmm_marker}/../rapids_logger/lib64" + ) + file(GLOB _rsc_marker_cccl_dirs + "${_rsc_librmm_marker}/lib64/rapids/cmake/cccl" + ) + file(GLOB _rsc_marker_rapids_logger_dirs "${_rsc_librmm_marker}/../rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_marker_nvtx3_dirs "${_rsc_librmm_marker}/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_marker_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_marker_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_marker_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_marker_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_marker_nvtx3_dirs}) + endif() foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") endforeach() diff --git a/pyproject.toml b/pyproject.toml index dc69471a..a3b07ede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,9 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # librmm headers/CMake config are needed at build time for Wilcoxon. + # CUDA wheel builds rewrite this to the matching cu12/cu13 package. + "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" @@ -38,7 +41,7 @@ rapids-cu13 = [ "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10", - "rmm-cu13>=25.10", + "librmm-cu13>=25.10", ] rapids-cu12 = [ "cupy-cuda12x", @@ -46,7 +49,7 @@ rapids-cu12 = [ "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10", - "rmm-cu12>=25.10", + "librmm-cu12>=25.10", ] doc = [ @@ -164,7 +167,7 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] repair-wheel-command = [ - "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 -w {dest_dir} {wheel}", + "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", ] [tool.cibuildwheel.macos] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 8b6af5f6..3c42f60a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -23,147 +23,6 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, return 0.0; } -/** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. - * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). - */ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - - const double* sv = sorted_vals + (size_t)col * n_rows; - - double local_sum = 0.0; - int tid = threadIdx.x; - - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != sv[i - 1]); - - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_count = i - lo; - - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; - } - } - - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } - - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; - - if (lane == 0) { - warp_sums[warp_id] = local_sum; - } - __syncthreads(); - - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (tid == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; - } - } - } -} - -/** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. - * - * Each block handles one column. Assumes input is sorted column-wise (F-order). - */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; - - // Each thread processes multiple rows - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; - - // Binary search for tie_start (first element equal to val) - int lo = 0, hi = i; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_start = lo; - - // Binary search for tie_end (last element equal to val) - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { - hi = mid - 1; - } else { - lo = mid; - } - } - int tie_end = lo; - - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - - // Write rank to original position - rk[si[i]] = avg_rank; - } -} - /** * OVO dense rank core. * diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 0ab5b26c..38fc25ec 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -15,29 +15,6 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); -} - -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); -} - static inline void launch_ovo_rank_dense( const float* ref_sorted, const float* grp_data, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, @@ -79,31 +56,6 @@ template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - // Tie correction kernel - m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "stream"_a = 0); - - // Average rank kernel - m.def( - "average_rank", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); - m.def( "ovo_rank_dense", [](gpu_array_f ref_sorted, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index afac20f2..9fd626b6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -49,15 +49,6 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } } -static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { - size_t bytes = 0; - auto* dk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, - n_segments, doff, doff + 1, 0, 32); - return bytes; -} - /** * Tier 1 dispatch: when the largest group fits in shared memory, a fused * bitonic-sort + binary-search kernel handles the whole group per block. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 603c1c96..d30f92cc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -2,151 +2,6 @@ #include -/** - * Fused rank-sum kernel: walk sorted data, compute per-group rank sums - * and tie correction without materializing a rank matrix. - * - * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting - * tie groups by adjacent comparison (sequential access, no binary search). - * Cross-boundary ties are resolved via binary search at chunk boundaries. - * - * When use_gmem is false, per-group accumulators live in shared memory - * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem - * is true, accumulators write directly to ``rank_sums`` in global memory, - * supporting an arbitrary number of groups. The caller must pre-zero - * ``rank_sums`` before launching in the gmem path. - * - * Shared memory layout: - * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) - * use_gmem=true: 32 doubles (warp buf only) - */ -__global__ void rank_sums_from_sorted_kernel( - const float* __restrict__ sorted_vals, - const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, - double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { - int col = blockIdx.x; - if (col >= n_cols) return; - - extern __shared__ double smem[]; - - double* grp_sums; - if (use_gmem) { - // Global memory path: write directly to output (must be pre-zeroed) - grp_sums = rank_sums + (size_t)col; // stride: n_cols - } else { - // Shared memory path: per-block accumulators - grp_sums = smem; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - grp_sums[g] = 0.0; - } - __syncthreads(); - } - - const float* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorted_row_idx + (size_t)col * n_rows; - - int chunk = (n_rows + blockDim.x - 1) / blockDim.x; - int my_start = threadIdx.x * chunk; - int my_end = my_start + chunk; - if (my_end > n_rows) my_end = n_rows; - - double local_tie_sum = 0.0; - - // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem - int acc_stride = use_gmem ? n_cols : 1; - - int i = my_start; - while (i < my_end) { - double val = sv[i]; - - int tie_local_end = i + 1; - while (tie_local_end < my_end && sv[tie_local_end] == val) - ++tie_local_end; - - int tie_global_start = i; - if (i == my_start && i > 0 && sv[i - 1] == val) { - int lo = 0, hi = i; - while (lo < hi) { - int mid = lo + (hi - lo) / 2; - if (sv[mid] < val) - lo = mid + 1; - else - hi = mid; - } - tie_global_start = lo; - } - - int tie_global_end = tie_local_end; - if (tie_local_end == my_end && tie_local_end < n_rows && - sv[tie_local_end] == val) { - int lo = tie_local_end, hi = n_rows - 1; - while (lo < hi) { - int mid = hi - ((hi - lo) >> 1); - if (sv[mid] > val) - hi = mid - 1; - else - lo = mid; - } - tie_global_end = lo + 1; - } - - int total_tie = tie_global_end - tie_global_start; - double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - - for (int j = i; j < tie_local_end; ++j) { - int grp = group_codes[si[j]]; - if (grp < n_groups) { - atomicAdd(&grp_sums[grp * acc_stride], avg_rank); - } - } - - if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { - double t = (double)total_tie; - local_tie_sum += t * t * t - t; - } - - i = tie_local_end; - } - - __syncthreads(); - - // Copy shared memory accumulators to global output (smem path only) - if (!use_gmem) { - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; - } - } - - if (compute_tie_corr) { - // Warp buf sits after accumulator array in shared memory. - // gmem path: warp buf starts at smem[0]. - // smem path: n_groups doubles, then warp buf. - int warp_buf_off = use_gmem ? 0 : n_groups; - double* warp_buf = smem + warp_buf_off; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = local_tie_sum; - __syncthreads(); - if (threadIdx.x < 32) { - double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - val / denom) : 1.0; - } - } - } -} - /** * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. * diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 90c54eb2..4fec5948 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -30,71 +30,6 @@ OVO_DEVICE_SPARSE_SUB_BATCH = 128 -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """ - Compute average ranks for each column using GPU kernel. - - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. - - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) - - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape - - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) - - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) - - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream - ) - - if return_sorted: - return matrix, sorted_vals - return matrix - - -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """ - Compute tie correction factor for Wilcoxon test. - - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. - """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) - - if n_rows < 2: - return correction - - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) - - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream - ) - - return correction - - def _extract_dense_rows_cols( X, row_ids: np.ndarray, start: int, stop: int ) -> cp.ndarray: diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 87030dfb..413cfe3b 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -7,7 +7,7 @@ import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc @@ -840,188 +840,3 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): np.testing.assert_array_equal( dense_df["pvals"].values, sparse_df["pvals"].values ) - - -# ============================================================================ -# Tests for ranking and tie correction kernels (edge cases from scipy) -# ============================================================================ - - -class TestRankingKernel: - """Tests for _average_ranks based on scipy.stats.rankdata edge cases.""" - - @pytest.fixture - def average_ranks(self): - """Import the ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - ) - - return _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_basic_ranking(self, average_ranks): - """Test basic average ranking on simple data.""" - values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_all_ties(self, average_ranks): - """All identical values should get the average rank.""" - values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_no_ties(self, average_ranks): - """All unique values should get sequential ranks.""" - values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_mixed_ties(self, average_ranks): - """Mix of ties and unique values.""" - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_negative_values(self, average_ranks): - """Test with negative values.""" - values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_single_element(self, average_ranks): - """Single element should have rank 1.""" - values = [42.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.0]) - - def test_two_elements_tied(self, average_ranks): - """Two tied elements should both have rank 1.5.""" - values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) - - def test_multiple_columns(self, average_ranks): - """Test ranking across multiple columns independently.""" - col0 = [3.0, 1.0, 2.0] - col1 = [1.0, 1.0, 2.0] - data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) - - np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) - np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) - - -class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" - - @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - _tie_correction, - ) - - return _tie_correction, _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_no_ties(self, tie_correction): - """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_all_ties(self, tie_correction): - """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_mixed_ties(self, tie_correction): - """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_two_elements_tied(self, tie_correction): - """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction - - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_single_element(self, tie_correction): - """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [42.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - # Single element: n^3 - n = 0, so formula gives 1.0 - np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - - def test_multiple_columns(self, tie_correction): - """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction - - col0 = [1.0, 2.0, 3.0] # No ties - col1 = [5.0, 5.0, 5.0] # All ties - data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) - - np.testing.assert_allclose( - result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 - ) - np.testing.assert_allclose( - result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 - ) - - def test_large_tie_groups(self, tie_correction): - """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) - values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) From 76389bbeeef5b030af0ac7763d15c42bc85a2ce5 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 18:10:14 +0200 Subject: [PATCH 04/36] update notebooks --- notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks b/notebooks index e5c97b34..4cdaa44f 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf +Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d From f69f1d85e6f527ce71948c004f9ece281248ab1e Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 20:03:26 +0200 Subject: [PATCH 05/36] make dense faster --- CMakeLists.txt | 2 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 351 +++---------- .../_cuda/wilcoxon/wilcoxon.cu | 483 ++++++++++++++++-- .../tools/_rank_genes_groups/_wilcoxon.py | 140 ++--- tests/test_rank_genes_groups_wilcoxon.py | 34 ++ 5 files changed, 629 insertions(+), 381 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 85fcfc2d..e880613d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + target_sources(_wilcoxon_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm) add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 3c42f60a..5af4e964 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -24,313 +24,118 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, } /** - * OVO dense rank core. + * OVR dense rank-sum kernel for data sorted by column. * - * ref_sorted is F-order and sorted independently for every column. - * grp_data is F-order and contains test-group rows concatenated by - * grp_offsets. One block computes one (column, test-group) result. - * - * This intentionally centralizes the OVO math; host/device and CSR/CSC/dense - * paths only need to materialize bounded dense column batches that feed this - * kernel. + * sorted_vals and sorted_row_idx are F-order arrays from a segmented + * SortPairs. One block owns one column, walks tie runs, and accumulates the + * average ranks per group without materializing a full rank matrix. */ -__global__ void ovo_rank_dense_kernel(const float* __restrict__ ref_sorted, - const float* __restrict__ grp_data, - const int* __restrict__ grp_offsets, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_ref, - int n_all_grp, int n_cols, int n_groups, - bool compute_tie_corr) { +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - const float* ref_col = ref_sorted + (long long)col * n_ref; - const float* grp_col = grp_data + (long long)col * n_all_grp + g_start; - - __shared__ double warp_buf[32]; - double local_rank = 0.0; - - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; + if (col >= n_cols) return; - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; + extern __shared__ double smem[]; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; + double* grp_sums; + if (use_gmem) { + grp_sums = rank_sums + (size_t)col; + } else { + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; } - int n_eq_ref = lo - n_lt_ref; + __syncthreads(); + } - int n_lt_grp = 0; - int n_eq_grp = 0; - for (int j = 0; j < n_grp; ++j) { - float u = grp_col[j]; - n_lt_grp += (u < v); - n_eq_grp += (u == v); - } + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - local_rank += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; - } + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; - double total_rank = wilcoxon_block_sum(local_rank, warp_buf); - if (threadIdx.x == 0) { - rank_sums[(size_t)grp * n_cols + col] = total_rank; - } + double local_tie_sum = 0.0; + int acc_stride = use_gmem ? n_cols : 1; - if (!compute_tie_corr) return; - __syncthreads(); + int i = my_start; + while (i < my_end) { + double val = sv[i]; - double local_tie = 0.0; + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) { + ++tie_local_end; + } - for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { - if (i == 0 || ref_col[i] != ref_col[i - 1]) { - float v = ref_col[i]; - int lo = i + 1, hi = n_ref; + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0; + int hi = i; while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; else - hi = m; - } - int count = lo - i; - for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); - if (count > 1) { - double t = (double)count; - local_tie += t * t * t - t; + hi = mid; } + tie_global_start = lo; } - } - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; - bool seen_in_group = false; - for (int j = 0; j < i; ++j) { - if (grp_col[j] == v) { - seen_in_group = true; - break; + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end; + int hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; } + tie_global_end = lo + 1; } - if (seen_in_group) continue; - - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - if (lo < n_ref && ref_col[lo] == v) continue; - - int count = 0; - for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); - if (count > 1) { - double t = (double)count; - local_tie += t * t * t - t; - } - } - - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - tie_corr[(size_t)grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } -} - -__global__ void ovo_rank_presorted_kernel(const float* __restrict__ ref_sorted, - const float* __restrict__ grp_sorted, - const int* __restrict__ grp_offsets, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, - int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr) { - int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - const float* ref_col = ref_sorted + (long long)col * n_ref; - const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; - - __shared__ double warp_buf[32]; - double local_rank = 0.0; - - int ref_lb = 0, ref_ub = 0; - int grp_lb = 0, grp_ub = 0; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; - - int lo = ref_lb, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - ref_lb = n_lt_ref; - lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; - ref_ub = lo; + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - lo = grp_lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } } - int n_lt_grp = lo; - grp_lb = n_lt_grp; - lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; } - int n_eq_grp = lo - n_lt_grp; - grp_ub = lo; - local_rank += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; - } - - double total_rank = wilcoxon_block_sum(local_rank, warp_buf); - if (threadIdx.x == 0) { - rank_sums[(size_t)grp * n_cols + col] = total_rank; + i = tie_local_end; } - if (!compute_tie_corr) return; __syncthreads(); - double local_tie = 0.0; - int grp_lb_tie = 0, grp_ub_tie = 0; - for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { - if (i == 0 || ref_col[i] != ref_col[i - 1]) { - float v = ref_col[i]; - int lo = i + 1, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_ref = lo - i; - - lo = grp_lb_tie; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; - } - int lb = lo; - grp_lb_tie = lb; - - lo = (grp_ub_tie > lb) ? grp_ub_tie : lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_grp = lo - lb; - grp_ub_tie = lo; - - int cnt = cnt_ref + cnt_grp; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; } } - int ref_lb_tie = 0; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - if (i == 0 || grp_col[i] != grp_col[i - 1]) { - float v = grp_col[i]; - int lo = ref_lb_tie, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - ref_lb_tie = lo; - if (lo < n_ref && ref_col[lo] == v) continue; - - lo = i + 1; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt = lo - i; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } + if (compute_tie_corr) { + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; + double tie_sum = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; } } - - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - tie_corr[(size_t)grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } } /** diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 38fc25ec..9212960b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,44 +1,20 @@ #include +#include + +#include +#include +#include + #include "../nb_types.h" #include "kernels_wilcoxon.cuh" +#include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovo_kernels.cuh" using namespace nb::literals; -// Constants for kernel launch configuration -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; -constexpr int OVO_THREADS_PER_BLOCK = 256; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; -} - -static inline void launch_ovo_rank_dense( - const float* ref_sorted, const float* grp_data, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, cudaStream_t stream) { - dim3 block(OVO_THREADS_PER_BLOCK); - dim3 grid(n_cols, n_groups); - ovo_rank_dense_kernel<<>>( - ref_sorted, grp_data, grp_offsets, rank_sums, tie_corr, n_ref, - n_all_grp, n_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_rank_dense_kernel); -} - -static inline void launch_ovo_rank_presorted( - const float* ref_sorted, const float* grp_sorted, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, cudaStream_t stream) { - dim3 block(OVO_THREADS_PER_BLOCK); - dim3 grid(n_cols, n_groups); - ovo_rank_presorted_kernel<<>>( - ref_sorted, grp_sorted, grp_offsets, rank_sums, tie_corr, n_ref, - n_all_grp, n_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_rank_presorted_kernel); -} - static inline void launch_ovr_rank_dense( const float* sorted_vals, const int* sorter, const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, @@ -52,45 +28,435 @@ static inline void launch_ovr_rank_dense( CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); } +static void launch_ovr_rank_dense_streaming( + const float* block, const int* group_codes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) { + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + } + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + if (sub_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVR sub-batch exceeds CUB int item limit"); + } + + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + const float* keys_in = block + (size_t)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVR streaming rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered_impl( + const float* ref_data, bool ref_is_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + std::vector h_offsets(n_groups + 1); + cudaStreamSynchronize(upstream_stream); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + if (sub_ref_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVO reference sub-batch exceeds CUB int item limit"); + } + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + if (sub_grp_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVO sub-batch exceeds CUB int item limit"); + } + + size_t grp_cub_temp_bytes = 0; + if (needs_tier3) { + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, grp_cub_temp_bytes, fk, fk, (int)sub_grp_items, + max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t ref_cub_temp_bytes = 0; + if (!ref_is_sorted) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_temp_bytes, fk, fk, (int)sub_ref_items, + sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + if (ref_is_sorted) { + bufs[s].ref_sorted = nullptr; + bufs[s].ref_seg_offsets = nullptr; + bufs[s].ref_cub_temp = nullptr; + } else { + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + } + bufs[s].grp_cub_temp = + needs_tier3 ? pool.alloc(grp_cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = ref_data + (size_t)col * n_ref; + const float* grp_sub = grp_data + (size_t)col * n_all_grp; + if (!ref_is_sorted) { + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + size_t temp = ref_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.ref_cub_temp, temp, ref_sub, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + ref_sub = buf.ref_sorted; + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, grp_sub, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + + size_t temp = grp_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.grp_cub_temp, temp, grp_sub, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVO tiered rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered( + const float* ref_sorted, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_sorted, true, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); +} + +static void launch_ovo_rank_dense_tiered_unsorted_ref( + const float* ref_data, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_data, false, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; m.def( - "ovo_rank_dense", + "ovo_rank_dense_tiered", [](gpu_array_f ref_sorted, gpu_array_f grp_data, gpu_array_c grp_offsets, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_ovo_rank_dense( - ref_sorted.data(), grp_data.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, (cudaStream_t)stream); + launch_ovo_rank_dense_tiered(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); }, "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "stream"_a = 0); m.def( - "ovo_rank_presorted", - [](gpu_array_f ref_sorted, - gpu_array_f grp_sorted, + "ovo_rank_dense_tiered_unsorted_ref", + [](gpu_array_f ref_data, + gpu_array_f grp_data, gpu_array_c grp_offsets, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_ovo_rank_presorted( - ref_sorted.data(), grp_sorted.data(), grp_offsets.data(), + launch_ovo_rank_dense_tiered_unsorted_ref( + ref_data.data(), grp_data.data(), grp_offsets.data(), rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, (cudaStream_t)stream); + n_groups, compute_tie_corr, sub_batch_cols, + (cudaStream_t)stream); }, - "ref_sorted"_a, "grp_sorted"_a, "grp_offsets"_a, "rank_sums"_a, + "ref_data"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "stream"_a = 0); m.def( "ovr_rank_dense", @@ -108,6 +474,23 @@ void register_bindings(nb::module_& m) { "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovr_rank_dense_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + std::uintptr_t stream) { + launch_ovr_rank_dense_streaming( + block.data(), group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); + }, + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 4fec5948..b96cfee6 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -21,13 +21,60 @@ MIN_GROUP_SIZE_WARNING = 25 DEFAULT_WILCOXON_CHUNK_SIZE = 512 -OVO_SORT_GROUP_THRESHOLD = 512 OVR_HOST_CSC_SUB_BATCH = 512 OVR_HOST_CSR_SUB_BATCH = 2048 OVR_DEVICE_CSC_SUB_BATCH = 2048 OVR_DEVICE_CSR_SUB_BATCH = 2048 OVO_HOST_SPARSE_SUB_BATCH = 256 OVO_DEVICE_SPARSE_SUB_BATCH = 128 +OVR_DENSE_SUB_BATCH = 64 +OVO_DENSE_TIERED_SUB_BATCH = 256 +DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 + + +def _maybe_preload_host_dense(rg: _RankGenes) -> None: + X = rg.X + if not isinstance(X, np.ndarray) or X.size == 0: + return + + try: + _, total = cp.cuda.runtime.memGetInfo() + except cp.cuda.runtime.CUDARuntimeError: + return + + if X.nbytes > total * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: + return + + registered = False + if X.flags.c_contiguous or X.flags.f_contiguous: + try: + cp.cuda.runtime.hostRegister(X.ctypes.data, X.nbytes, 0) + registered = True + except cp.cuda.runtime.CUDARuntimeError: + registered = False + + try: + X_gpu = cp.asarray(X) + cp.cuda.get_current_stream().synchronize() + except cp.cuda.memory.OutOfMemoryError: + cp.get_default_memory_pool().free_all_blocks() + return + except cp.cuda.runtime.CUDARuntimeError: + return + finally: + if registered: + try: + cp.cuda.runtime.hostUnregister(X.ctypes.data) + except cp.cuda.runtime.CUDARuntimeError: + pass + rg.X = X_gpu + + +def _get_dense_column_block_f32(X, start: int, stop: int) -> cp.ndarray: + """Extract a dense column block as F-order float32 CuPy memory.""" + if isinstance(X, np.ndarray | cp.ndarray): + return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + raise TypeError(f"Expected dense matrix, got {type(X)}") def _extract_dense_rows_cols( @@ -333,6 +380,7 @@ def wilcoxon( return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" + _maybe_preload_host_dense(rg) # Compute basic stats - uses Aggregate if on GPU, else defers to chunks rg._basic_stats() X = rg.X @@ -591,32 +639,29 @@ def _wilcoxon_vs_rest( for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, - ) + if rg._compute_stats_in_chunks: + block = _get_column_block(X, start, stop) + rg._accumulate_chunk_stats_vs_rest( + block, + start, + stop, + group_matrix=group_matrix, + group_sizes_dev=group_sizes_dev, + n_cells=n_cells, + ) + block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) + else: + block_f32 = _get_dense_column_block_f32(X, start, stop) - block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) - sorter = cp.asfortranarray(cp.argsort(block_f32, axis=0).astype(cp.int32)) - sorted_vals = cp.asfortranarray(cp.take_along_axis(block_f32, sorter, axis=0)) n_cols = stop - start - rank_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) tie_corr = ( cp.empty(n_cols, dtype=cp.float64) if tie_correct else cp.ones(n_cols, dtype=cp.float64) ) - _wc.ovr_rank_dense( - sorted_vals, - sorter, + _wc.ovr_rank_dense_streaming( + block_f32, group_codes_gpu, rank_sums, tie_corr, @@ -624,6 +669,7 @@ def _wilcoxon_vs_rest( n_cols=n_cols, n_groups=n_groups, compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DENSE_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 @@ -713,8 +759,6 @@ def _wilcoxon_with_reference( offsets_gpu = cp.asarray(offsets_np) n_all_grp = int(all_grp_row_ids.size) n_test = len(test_group_indices) - max_test_size = int(np.diff(offsets_np).max(initial=0)) - use_presorted_groups = max_test_size > OVO_SORT_GROUP_THRESHOLD test_sizes = cp.asarray( group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( np.float64, copy=False @@ -976,45 +1020,25 @@ def _wilcoxon_with_reference( group_sizes=group_sizes, ) - ref_sorted = cp.asfortranarray(cp.sort(ref_block.astype(cp.float32), axis=0)) - grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") + grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) - if use_presorted_groups: - grp_rank_input = cp.empty_like(grp_f32) - for slot in range(n_test): - begin = int(offsets_np[slot]) - end = int(offsets_np[slot + 1]) - grp_rank_input[begin:end] = cp.sort(grp_f32[begin:end], axis=0) - grp_rank_input = cp.asfortranarray(grp_rank_input) - _wc.ovo_rank_presorted( - ref_sorted, - grp_rank_input, - offsets_gpu, - rank_sums, - tie_corr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_cols, - n_groups=n_test, - compute_tie_corr=tie_correct, - stream=cp.cuda.get_current_stream().ptr, - ) - else: - _wc.ovo_rank_dense( - ref_sorted, - grp_f32, - offsets_gpu, - rank_sums, - tie_corr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_cols, - n_groups=n_test, - compute_tie_corr=tie_correct, - stream=cp.cuda.get_current_stream().ptr, - ) + _wc.ovo_rank_dense_tiered_unsorted_ref( + ref_f32, + grp_f32, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) n_combined = test_sizes + n_ref expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 413cfe3b..6e3dbf89 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -244,6 +244,40 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars assert params["reference"] == reference +def test_rank_genes_groups_wilcoxon_dense_ovr_ties_match_scanpy(): + rng = np.random.default_rng(16) + X = rng.integers(0, 40, size=(128, 7)).astype(np.float32) + labels = rng.integers(0, 7, size=128).astype(str) + adata_gpu = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "rest", + "use_raw": False, + "tie_correct": True, + "n_genes": adata_gpu.n_vars, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for group in gpu_result["scores"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + np.testing.assert_allclose( + gpu_result["scores"][group], cpu_result["scores"][group], rtol=1e-13 + ) + np.testing.assert_allclose( + gpu_result["pvals"][group], cpu_result["pvals"][group], rtol=1e-13 + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" From a0e9b0c3b1fe701c22baad8e63f43de2dfda3bf9 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 23:08:15 +0200 Subject: [PATCH 06/36] update tests and fix issues --- .github/workflows/publish.yml | 6 +- .gitignore | 2 +- CMakeLists.txt | 14 +- pyproject.toml | 4 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 2 + .../_cuda/wilcoxon/wilcoxon.cu | 48 ++--- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 44 ++++- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 68 +++++-- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 128 +++++++++---- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 23 +-- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 53 ++++-- .../_cuda/wilcoxon/wilcoxon_rmm.cu | 4 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 23 +-- .../tools/_rank_genes_groups/__init__.py | 147 ++++----------- .../tools/_rank_genes_groups/_core.py | 68 ++++++- .../tools/_rank_genes_groups/_utils.py | 55 +++--- .../tools/_rank_genes_groups/_wilcoxon.py | 51 +++-- .../_rank_genes_groups/_wilcoxon_binned.py | 6 +- tests/test_rank_genes_groups_ttest.py | 35 ++-- tests/test_rank_genes_groups_wilcoxon.py | 177 +++++++++++++++--- 20 files changed, 626 insertions(+), 332 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4ca5d522..0ea4ee5e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -147,6 +147,8 @@ jobs: LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > + rm -f build/.librmm_dir && + mkdir -p build && python -m pip install -U pip scikit-build-core cmake ninja nanobind librmm-cu${{ matrix.cuda_major }} && @@ -157,8 +159,8 @@ jobs: ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && ldconfig && - python -c "import librmm; print(librmm.__path__[0])" > /tmp/.librmm_dir && - echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" + python -c "import librmm; print(librmm.__path__[0])" > build/.librmm_dir && + echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" diff --git a/.gitignore b/.gitignore index 6994e147..4a7497ba 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ CLAUDE.md # tmp_scripts tmp_scripts/ -benchmarks/ +/benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index e880613d..4e404263 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,9 +50,19 @@ if (RSC_BUILD_EXTENSIONS) if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") endif() - if(EXISTS "/tmp/.librmm_dir") - file(READ "/tmp/.librmm_dir" _rsc_librmm_marker) + # Wheel builds install librmm/rapids_logger into the isolated build env and + # write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks + # those shared libraries into /usr/local/lib so auditwheel can see and exclude + # them instead of bundling RAPIDS runtime libraries into the wheel. + if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake") + set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}") + elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir") + file(READ "${CMAKE_SOURCE_DIR}/build/.librmm_dir" _rsc_librmm_marker) string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + else() + set(_rsc_librmm_marker "") + endif() + if(NOT "${_rsc_librmm_marker}" STREQUAL "" AND EXISTS "${_rsc_librmm_marker}/lib64/cmake/rmm/rmm-config.cmake") file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") file(GLOB _rsc_marker_rapids_prefixes "${_rsc_librmm_marker}/lib64" diff --git a/pyproject.toml b/pyproject.toml index a3b07ede..b4940b18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,9 @@ requires = [ "nanobind>=2.0.0", "setuptools-scm>=8", # librmm headers/CMake config are needed at build time for Wilcoxon. - # CUDA wheel builds rewrite this to the matching cu12/cu13 package. + # Generic isolated source builds default to CUDA 12. CUDA wheel builds + # rewrite this to the matching cu12/cu13 package; CUDA 13 source builds + # should build in an existing RAPIDS env with --no-build-isolation. "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 5b4c0b8c..a8e9ed4f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -2,6 +2,8 @@ #include +#include "wilcoxon_fast_common.cuh" + // ============================================================================ // Warp reduction helper (sum doubles across block via warp_buf) // ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 9212960b..d314b289 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -41,17 +41,14 @@ static void launch_ovr_rank_dense_streaming( } size_t sub_items = (size_t)n_rows * sub_batch_cols; - if (sub_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVR sub-batch exceeds CUB int item limit"); - } + int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); size_t cub_temp_bytes = 0; { auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + nullptr, cub_temp_bytes, fk, fk, iv, iv, sub_items_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -97,7 +94,8 @@ static void launch_ovr_rank_dense_streaming( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense OVR active sub-batch"); int s = batch_idx % n_streams; cudaStream_t stream = streams[s]; auto& buf = bufs[s]; @@ -184,32 +182,30 @@ static void launch_ovo_rank_dense_tiered_impl( n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - if (sub_ref_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVO reference sub-batch exceeds CUB int item limit"); - } + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense OVO reference sub-batch"); size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - if (sub_grp_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVO sub-batch exceeds CUB int item limit"); - } + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense OVO group sub-batch"); size_t grp_cub_temp_bytes = 0; if (needs_tier3) { - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense OVO group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, grp_cub_temp_bytes, fk, fk, (int)sub_grp_items, - max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); + nullptr, grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); } size_t ref_cub_temp_bytes = 0; if (!ref_is_sorted) { auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_temp_bytes, fk, fk, (int)sub_ref_items, + nullptr, ref_cub_temp_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } @@ -270,7 +266,9 @@ static void launch_ovo_rank_dense_tiered_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_seg = n_sort_groups * sub_batch_cols; + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense OVO group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_seg); bufs[s].grp_seg_ends = pool.alloc(max_seg); } else { @@ -287,8 +285,12 @@ static void launch_ovo_rank_dense_tiered_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense OVO active group sub-batch"); int s = batch_idx % n_streams; cudaStream_t stream = streams[s]; auto& buf = bufs[s]; @@ -343,7 +345,9 @@ static void launch_ovo_rank_dense_tiered_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "Dense OVO active group segment count"); int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 2d5b3f2c..ec723b55 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -48,6 +49,39 @@ constexpr int TIER1_GROUP_THRESHOLD = 2500; // 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; +static inline size_t wilcoxon_max_smem_per_block() { + int device = 0; + int max_smem = 0; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + device); + return (size_t)max_smem; +} + +static inline int checked_cub_items(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds CUB int item limit"); + } + return (int)count; +} + +static inline int checked_int_span(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds int32 offset limit"); + } + return (int)count; +} + +static inline int checked_int_product(size_t a, size_t b, const char* context) { + if (a != 0 && b > (size_t)std::numeric_limits::max() / a) { + throw std::runtime_error(std::string(context) + + " exceeds int32 item limit"); + } + return (int)(a * b); +} + // --------------------------------------------------------------------------- // RAII guard for cudaHostRegister. Unregisters on scope exit even when an // exception unwinds — prevents leaked host pinning on stream-sync failures. @@ -60,9 +94,9 @@ struct HostRegisterGuard { if (p && bytes > 0) { cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { - // Already-registered memory is fine; anything else means the - // subsequent kernels would read garbage from an unmapped - // pointer, so surface the error immediately. + // Already-registered memory belongs to another owner; use it + // without unregistering here. Other failures mean mapped reads + // would be unsafe, so surface them immediately. if (err == cudaErrorHostMemoryAlreadyRegistered) { cudaGetLastError(); // clear sticky error flag } else { @@ -116,6 +150,10 @@ struct RmmScratchPool { template T* alloc(size_t count) { if (count == 0) count = 1; + if (count > std::numeric_limits::max() / sizeof(T)) { + throw std::runtime_error( + "Wilcoxon scratch allocation size overflow"); + } size_t bytes = count * sizeof(T); void* ptr = wilcoxon_rmm_allocate(bytes); bufs.push_back({ptr, bytes}); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index b195bee0..b60b87ff 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -64,17 +64,23 @@ static void ovo_streaming_csr_impl( size_t cub_temp_bytes = 0; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSR group sub-batch"); + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = cub_grp_bytes; } std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); int* d_sort_group_ids = nullptr; if (needs_tier3) { @@ -110,7 +116,9 @@ static void ovo_streaming_csr_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_seg = n_sort_groups * sub_batch_cols; + int max_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_seg); bufs[s].grp_seg_ends = pool.alloc(max_seg); } else { @@ -127,6 +135,8 @@ static void ovo_streaming_csr_impl( for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); size_t cache_ref_items = (size_t)n_ref * cache_cols; + int cache_ref_items_i32 = checked_cub_items( + cache_ref_items, "OVO device CSR reference cache"); ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); @@ -136,36 +146,39 @@ static void ovo_streaming_csr_impl( float* d_ref_sorted = (float*)ref_sorted_buf.data(); int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); - cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float)); + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float), + ref_stream); int tpb_ref_extract = round_up_to_warp(n_ref); int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; - csr_extract_dense_kernel<<>>( + csr_extract_dense_kernel<<>>( csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, cache_col, cache_col + cache_cols); CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, 0); + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, ref_stream); size_t ref_cub_bytes = 0; auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, (int)cache_ref_items, cache_cols, + nullptr, ref_cub_bytes, fk, fk, cache_ref_items_i32, cache_cols, doff, doff + 1, BEGIN_BIT, END_BIT); ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); size_t ref_temp = ref_cub_bytes; cub::DeviceSegmentedRadixSort::SortKeys( ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, - (int)cache_ref_items, cache_cols, d_ref_seg_offsets, - d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT); - cudaDeviceSynchronize(); + cache_ref_items_i32, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); int col = cache_col; int cache_stop = cache_col + cache_cols; int batch_idx = 0; while (col < cache_stop) { int sb_cols = std::min(sub_batch_cols, cache_stop - col); - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSR active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -224,7 +237,9 @@ static void ovo_streaming_csr_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSR active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -277,6 +292,7 @@ static void ovo_streaming_csr_impl( } } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); + cudaStreamDestroy(ref_stream); } /** @@ -316,23 +332,29 @@ static void ovo_streaming_csc_impl( size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO device CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSC group sub-batch"); size_t cub_ref_bytes = 0; { auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } @@ -380,7 +402,9 @@ static void ovo_streaming_csc_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); } else { @@ -397,8 +421,12 @@ static void ovo_streaming_csc_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO device CSC active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSC active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -465,7 +493,9 @@ static void ovo_streaming_csc_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSC active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } @@ -82,8 +88,11 @@ static void ovo_streaming_csc_host_impl( int sb = std::min(sub_batch_cols, n_cols - col_start); IndptrT ptr_start = h_indptr[col_start]; int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = (int)(h_indptr[col_start + i] - ptr_start); + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVO host CSC rebased column offsets"); + } } int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); @@ -159,7 +168,9 @@ static void ovo_streaming_csc_host_impl( compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC stream group segment count"); bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); } else { @@ -186,8 +197,12 @@ static void ovo_streaming_csc_host_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_actual = n_ref * sb_cols; - int sb_grp_actual = n_all_grp * sb_cols; + int sb_ref_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO host CSC active reference sub-batch"); + int sb_grp_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO host CSC active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -196,6 +211,7 @@ static void ovo_streaming_csc_host_impl( IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); + checked_int_span(nnz, "OVO host CSC active batch nnz"); cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, @@ -276,7 +292,9 @@ static void ovo_streaming_csc_host_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "OVO host CSC active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<< (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference row exceeds int32 compacted nnz limit"); + } + int nnz_i = (int)row_nnz; + if ((size_t)h_ref_indptr_compact[i] + (size_t)nnz_i > + (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference compacted nnz exceeds int32 limit"); + } h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; } int ref_nnz = h_ref_indptr_compact[n_ref]; @@ -462,8 +490,11 @@ static void ovo_streaming_csr_host_impl( if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; if (K > max_pack_K) max_pack_K = K; - int pack_items = pk.n_rows * pk.sb_cols; + int pack_items = + checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, + "OVO host CSR pack dense slab"); if (pack_items > max_pack_items) max_pack_items = pack_items; + checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; } int max_group_rows = max_pack_rows; @@ -530,12 +561,37 @@ static void ovo_streaming_csr_host_impl( cudaMemcpyHostToDevice); // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- - float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + size_t ref_items = (size_t)n_ref * (size_t)n_cols; + if (n_ref > 0 && (size_t)n_cols > (size_t)std::numeric_limits::max() / + (size_t)n_ref) { + throw std::runtime_error( + "OVO host CSR dense reference cache exceeds CUB int item limit; " + "use native CSC/device sparse input or reduce genes/reference " + "size"); + } + if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { + throw std::runtime_error( + "OVO host CSR dense reference cache size overflows size_t"); + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess && + total_bytes > 0 && ref_items * 2 * sizeof(float) > total_bytes) { + throw std::runtime_error( + "OVO host CSR dense reference cache requires more GPU memory than " + "the device provides; use native CSC/device sparse input or reduce " + "genes/reference size"); + } + int ref_items_i32 = + checked_cub_items(ref_items, "OVO host CSR dense reference cache"); + float* d_ref_sorted = pool.alloc(ref_items); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); { ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); - ScopedCudaBuffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float)); + ScopedCudaBuffer ref_dense_buf(ref_items * sizeof(float)); ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); @@ -552,7 +608,7 @@ static void ovo_streaming_csr_host_impl( // pass over PCIe, no intermediate native-dtype GPU buffer. if (n_ref > 0 && ref_nnz > 0) { csr_gather_cast_accumulate_mapped_kernel - <<>>( + <<>>( d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, d_ref_indptr, /*d_stats_codes=*/nullptr, /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, @@ -562,12 +618,12 @@ static void ovo_streaming_csr_host_impl( } // Extract ref dense (F-order) from compacted CSR. - cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + cudaMemsetAsync(d_ref_dense, 0, ref_items * sizeof(float), ref_stream); { csr_extract_dense_identity_rows_unsorted_kernel - <<>>(d_ref_data_f32, d_ref_indices, - d_ref_indptr, d_ref_dense, n_ref, - 0, n_cols); + <<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, d_ref_dense, + n_ref, 0, n_cols); CUDA_CHECK_LAST_ERROR( csr_extract_dense_identity_rows_unsorted_kernel); } @@ -578,18 +634,18 @@ static void ovo_streaming_csr_host_impl( auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), - n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + nullptr, ref_cub_bytes, fk, fk, ref_items_i32, n_cols, doff, + doff + 1, BEGIN_BIT, END_BIT); } ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); - upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); size_t temp = ref_cub_bytes; cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, - (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, - BEGIN_BIT, END_BIT); - cudaDeviceSynchronize(); + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, ref_items_i32, + n_cols, d_ref_seg, d_ref_seg + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); } // ref scratch drops here + cudaStreamDestroy(ref_stream); // ---- Phase 2: Per-pack streaming ---- auto t1 = make_tier1_config(h_grp_offsets, n_test); @@ -604,11 +660,15 @@ static void ovo_streaming_csr_host_impl( size_t cub_grp_bytes = 0; if (may_need_cub && max_sub_items > 0) { + int max_sub_items_i32 = + checked_cub_items(max_sub_items, "OVO host CSR group pack"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); - int max_segments = max_pack_K * max_pack_sb_cols; + int max_segments = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR max group segment count"); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + nullptr, cub_grp_bytes, fk, fk, max_sub_items_i32, max_segments, doff, doff + 1, BEGIN_BIT, END_BIT); } @@ -632,7 +692,9 @@ static void ovo_streaming_csr_host_impl( double* d_tie_corr; }; std::vector bufs(n_streams); - int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + int max_pack_kernel_seg = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR pack segment buffer"); for (int s = 0; s < n_streams; s++) { bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); @@ -660,8 +722,6 @@ static void ovo_streaming_csr_host_impl( } } - cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams - for (int p = 0; p < (int)packs.size(); p++) { const Pack& pack = packs[p]; int K = pack.end - pack.first; @@ -742,7 +802,9 @@ static void ovo_streaming_csr_host_impl( int col = 0; while (col < n_cols) { int sb_cols = std::min(pack_sb, n_cols - col); - int sb_items = pack_rows * sb_cols; + int sb_items = + checked_int_product((size_t)pack_rows, (size_t)sb_cols, + "OVO host CSR active group sub-batch"); cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), stream); @@ -798,7 +860,9 @@ static void ovo_streaming_csr_host_impl( upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (pack_has_above_t2) { - int n_seg = pack_n_sort_groups * sb_cols; + int n_seg = checked_int_product( + (size_t)pack_n_sort_groups, (size_t)sb_cols, + "OVO host CSR active group segment count"); { int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<< diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 006002b9..2323e27f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -4,7 +4,7 @@ template __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, const IndptrT* __restrict__ indptr, - int* __restrict__ col_counts, + unsigned int* __restrict__ col_counts, int n_rows, int n_cols) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= n_rows) return; @@ -12,7 +12,7 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, IndptrT re = indptr[row + 1]; for (IndptrT p = rs; p < re; ++p) { int c = (int)indices[p]; - if (c < n_cols) atomicAdd(&col_counts[c], 1); + if (c < n_cols) atomicAdd(&col_counts[c], 1u); } } @@ -49,24 +49,9 @@ __global__ void csr_scatter_to_csc_kernel( } } -/** - * Decide whether to use shared or global memory for OVR rank accumulators. - * Returns the smem size to request and sets use_gmem accordingly. - */ -static int query_max_smem_per_block() { - static int cached = -1; - if (cached < 0) { - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, - device); - } - return cached; -} - static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } @@ -81,7 +66,7 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 6eae2a28..257bbbb3 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -32,10 +32,12 @@ static void ovr_sparse_csc_host_streaming_impl( // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -98,8 +100,11 @@ static void ovr_sparse_csc_host_streaming_impl( int sb = std::min(sub_batch_cols, n_cols - col_start); IndptrT ptr_start = h_indptr[col_start]; int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = (int)(h_indptr[col_start + i] - ptr_start); + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVR host CSC rebased column offsets"); + } } int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); @@ -143,7 +148,8 @@ static void ovr_sparse_csc_host_streaming_impl( IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = (int)(ptr_end - ptr_start); + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR host CSC active batch nnz"); // H2D: transfer sparse data for this column range (native dtype) if (batch_nnz > 0) { @@ -263,7 +269,7 @@ static void ovr_sparse_csr_host_streaming_impl( size_t total_nnz = (size_t)h_indptr[n_rows]; // ---- Phase 0: CPU planning in native CSR order ---- - std::vector h_col_counts(n_cols, 0); + std::vector h_col_counts(n_cols, 0); for (int row = 0; row < n_rows; row++) { IndptrT rs = h_indptr[row]; IndptrT re = h_indptr[row + 1]; @@ -282,7 +288,9 @@ static void ovr_sparse_csr_host_streaming_impl( int sb_cols = std::min(sub_batch_cols, n_cols - col_start); int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; for (int i = 0; i < sb_cols; i++) - off[i + 1] = off[i] + h_col_counts[col_start + i]; + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR host CSR rebased column offsets"); h_batch_nnz[b] = (size_t)off[sb_cols]; if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } @@ -295,10 +303,12 @@ static void ovr_sparse_csr_host_streaming_impl( // ---- Phase 1: allocate per-stream bounded work buffers ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR host CSR sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -427,7 +437,8 @@ static void ovr_sparse_csr_host_streaming_impl( int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - int batch_nnz = (int)h_batch_nnz[b]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR host CSR active batch nnz"); int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), @@ -546,10 +557,12 @@ static void ovr_sparse_csc_streaming_impl( // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR device CSC sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -597,7 +610,8 @@ static void ovr_sparse_csc_streaming_impl( int ptr_start = h_indptr[col]; int ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = ptr_end - ptr_start; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR device CSC active batch nnz"); // Compute rebased segment offsets on GPU (avoids host pinned-buffer // race) @@ -684,16 +698,16 @@ static void ovr_sparse_csr_streaming_impl( // ---- Phase 0: Planning — count nnz per column via histogram ---- RmmScratchPool pool; - int* d_col_counts = pool.alloc(n_cols); - cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); { int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; csr_col_histogram_kernel<<>>( csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); } - std::vector h_col_counts(n_cols); - cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost); // Per-batch prefix sums on host @@ -710,7 +724,9 @@ static void ovr_sparse_csr_streaming_impl( int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; off[0] = 0; for (int i = 0; i < sb_cols; i++) - off[i + 1] = off[i] + h_col_counts[col_start + i]; + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR device CSR rebased column offsets"); h_batch_nnz[b] = (size_t)off[sb_cols]; if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } @@ -724,10 +740,12 @@ static void ovr_sparse_csr_streaming_impl( // ---- Phase 1: Allocate per-stream buffers ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR device CSR sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -796,7 +814,8 @@ static void ovr_sparse_csr_streaming_impl( int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - int batch_nnz = (int)h_batch_nnz[b]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR device CSR active batch nnz"); // D2D copy pre-computed col_offsets for this batch int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu index 26e37f42..94a101e9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -10,8 +10,8 @@ void* wilcoxon_rmm_allocate(size_t bytes) { return rmm::mr::get_current_device_resource()->allocate_sync(bytes); } catch (std::exception const& e) { throw std::runtime_error( - std::string("RMM allocation failed in Wilcoxon scratch: ") + - e.what()); + std::string("RMM allocation failed in Wilcoxon scratch (") + + std::to_string(bytes) + " bytes): " + e.what()); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index d30f92cc..efdac894 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -25,6 +25,10 @@ * grp_nz_count[n_groups] nonzero-per-group counters * warp_buf[32] tie-correction reduction scratch * + * n_rows is the ranking population, including rows whose group code is the + * n_groups sentinel. Sentinel rows contribute to the "rest" distribution and + * tie-correction denominator but do not receive rank-sum accumulation. + * * Grid: (sb_cols,) Block: (tpb,) */ template @@ -223,28 +227,11 @@ __global__ void rank_sums_sparse_ovr_kernel( } } -/** - * Decide whether the host cast+stats kernels can use per-block shared memory - * accumulators. Large group counts exceed the dynamic smem launch limit, so - * those cases fall back to direct global-memory atomics after zeroing the - * per-stream output buffers. - */ -static int wilcoxon_cast_max_smem_per_block() { - static int cached = -1; - if (cached < 0) { - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, - device); - } - return cached; -} - static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, bool compute_nnz, bool& use_gmem) { int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); size_t need = (size_t)n_arrays * n_groups * sizeof(double); - if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index d399a301..a204d73e 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,100 +21,31 @@ ] -class _LazyRankGenesColumn: - def __init__( - self, - values: np.ndarray | None = None, - *, - var_names: np.ndarray | None = None, - gene_indices: np.ndarray | None = None, - dtype: str | np.dtype, - ) -> None: - self._values = values - self._var_names = var_names - self._gene_indices = gene_indices - self._dtype = np.dtype(dtype) - - def __len__(self) -> int: - if self._values is not None: - return int(self._values.shape[0]) - return int(self._gene_indices.shape[0]) - - def __getitem__(self, key): - if self._values is not None: - return self._values[key] - return self._var_names[self._gene_indices[key]] - - def __iter__(self): - for idx in range(len(self)): - yield self[idx] - - def __array__(self, dtype=None, copy=None) -> np.ndarray: - if self._values is not None: - arr = np.asarray(self._values, dtype=self._dtype) - else: - arr = np.asarray(self._var_names[self._gene_indices], dtype=self._dtype) - if dtype is not None: - arr = np.asarray(arr, dtype=dtype) - if copy: - arr = arr.copy() - return arr - - -class _LazyRankGenesRecords(dict): - def __init__( - self, group_names: np.ndarray, columns: dict[str, object], dtype: str | np.dtype - ) -> None: - super().__init__(columns) - self._group_names = tuple(str(name) for name in group_names) - self._dtype = np.dtype([(name, np.dtype(dtype)) for name in self._group_names]) - - @property - def dtype(self) -> np.dtype: - return self._dtype - - def __getitem__(self, key): - if isinstance(key, str): - return super().__getitem__(key) - return np.asarray(self)[key] - - def __array__(self, dtype=None, copy=None) -> np.ndarray: - out = np.empty(len(next(iter(self.values()))) if self else 0, dtype=self._dtype) - for name in self._group_names: - out[name] = np.asarray(super().__getitem__(name)) - if dtype is not None: - out = np.asarray(out, dtype=dtype) - if copy: - out = out.copy() - return out - - def copy(self) -> np.ndarray: - return np.asarray(self).copy() - - -def _array_result_to_lazy_records( +def _array_result_to_records( arrays: dict[str, object], field: str, dtype: str | np.dtype -) -> _LazyRankGenesRecords: - group_names = arrays["group_names"] - values = arrays[field] - columns = { - str(group_name): _LazyRankGenesColumn(values[row], dtype=dtype) - for row, group_name in enumerate(group_names) - } - return _LazyRankGenesRecords(group_names, columns, dtype) - - -def _array_result_to_lazy_names(arrays: dict[str, object]) -> _LazyRankGenesRecords: - group_names = arrays["group_names"] - var_names = arrays["var_names"] - gene_indices = arrays["gene_indices"] - columns = { - str(group_name): _LazyRankGenesColumn( - var_names=var_names, gene_indices=gene_indices[row], dtype=object - ) - for row, group_name in enumerate(group_names) - } - return _LazyRankGenesRecords(group_names, columns, object) +) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + values = np.asarray(arrays[field]) + out = np.empty( + values.shape[1], + dtype=[(group_name, np.dtype(dtype)) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = values[row] + return out + + +def _array_result_to_names(arrays: dict[str, object]) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + var_names = np.asarray(arrays["var_names"]) + gene_indices = np.asarray(arrays["gene_indices"], dtype=np.intp) + out = np.empty( + gene_indices.shape[1], + dtype=[(group_name, object) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = var_names[gene_indices[row]] + return out def rank_genes_groups( @@ -146,8 +77,8 @@ def rank_genes_groups( Rank genes for characterizing groups using GPU acceleration. Expects nonnegative expression data. Log1p/log-normalized data is expected - for biologically meaningful log fold changes; sparse inputs with explicit - negative values are rejected. + for biologically meaningful log fold changes; negative values are rejected + for eager in-memory inputs. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and @@ -235,10 +166,8 @@ def rank_genes_groups( Returns ------- - Updates `adata` with the following fields. Rank result fields are lazy - Scanpy-compatible record objects: group fields can be indexed like - structured arrays, while full structured arrays are materialized only when - requested through NumPy conversion or `.copy()`. + Updates `adata` with the following fields. Rank result fields are + Scanpy-compatible structured arrays. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -269,7 +198,7 @@ def rank_genes_groups( if "return_format" in kwds: msg = ( "return_format has been removed; rank_genes_groups always writes " - "lazy Scanpy-compatible results to adata.uns." + "Scanpy-compatible structured results to adata.uns." ) raise TypeError(msg) @@ -357,23 +286,15 @@ def rank_genes_groups( arrays = test_obj.stats_arrays or {} adata.uns[key_added] = {"params": params} if arrays and len(arrays.get("group_names", ())) > 0: - adata.uns[key_added]["names"] = _array_result_to_lazy_names(arrays) - for col, dtype in { - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - }.items(): + adata.uns[key_added]["names"] = _array_result_to_names(arrays) + for col in ("scores", "logfoldchanges", "pvals", "pvals_adj"): if col in arrays: values = arrays[col] - if hasattr(values, "dtype"): - dtype = values.dtype - adata.uns[key_added][col] = _array_result_to_lazy_records( - arrays, col, dtype - ) + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_records(arrays, col, dtype) + groups_names = [str(name) for name in test_obj.groups_order] if test_obj.pts is not None: - groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( test_obj.pts.T, index=test_obj.var_names, columns=groups_names ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index acfbe2e2..af91e4d5 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -35,6 +35,41 @@ """, "fdr_bh_reverse_cummin", ) +_GROUP_CHUNK_STATS_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void group_chunk_stats( + const double* block, + const int* group_codes, + double* group_sums, + double* group_sum_sq, + double* group_nnz, + const int n_rows, + const int n_cols, + const int n_groups, + const bool compute_nnz +) { + const long long idx = blockIdx.x * blockDim.x + threadIdx.x; + const long long total = static_cast(n_rows) * n_cols; + if (idx >= total) { + return; + } + const int row = idx % n_rows; + const int col = idx / n_rows; + const int group = group_codes[row]; + if (group < 0 || group >= n_groups) { + return; + } + const double value = block[idx]; + const long long out = static_cast(group) * n_cols + col; + atomicAdd(group_sums + out, value); + atomicAdd(group_sum_sq + out, value * value); + if (compute_nnz && value != 0.0) { + atomicAdd(group_nnz + out, 1.0); + } +} +""", + "group_chunk_stats", +) _RANK_SORT_MIN_ELEMENTS = 1_000_000 _RANK_SORT_MAX_WORKERS = 64 @@ -258,7 +293,7 @@ def _accumulate_chunk_stats_vs_rest( start: int, stop: int, *, - group_matrix: cp.ndarray, + group_codes_dev: cp.ndarray, group_sizes_dev: cp.ndarray, n_cells: int, ) -> None: @@ -268,9 +303,31 @@ def _accumulate_chunk_stats_vs_rest( rest_sizes = n_cells - group_sizes_dev - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) + n_groups = len(self.groups_order) + n_cols = stop - start + group_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_sum_sq = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_nnz = ( + cp.zeros((n_groups, n_cols), dtype=cp.float64) if self.comp_pts else None + ) + n_items = n_cells * n_cols + threads = 256 + blocks = (n_items + threads - 1) // threads + _GROUP_CHUNK_STATS_KERNEL( + (blocks,), + (threads,), + ( + block, + group_codes_dev, + group_sums, + group_sum_sq, + group_nnz if group_nnz is not None else group_sums, + np.int32(n_cells), + np.int32(n_cols), + np.int32(n_groups), + self.comp_pts, + ), + ) # Means chunk_means = group_sums / group_sizes_dev[:, None] @@ -283,7 +340,6 @@ def _accumulate_chunk_stats_vs_rest( # Pts (fraction expressing) if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) # Rest statistics @@ -439,7 +495,7 @@ def compute_statistics( raise ValueError(msg) self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) self._wilcoxon_gpu_result = None - self._store_wilcoxon_gpu_result = n_genes_user is not None + self._store_wilcoxon_gpu_result = True try: test_results = self.wilcoxon( tie_correct=tie_correct, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 4ec37e40..de91e25d 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -16,32 +16,47 @@ EPS = 1e-9 WARP_SIZE = 32 MAX_THREADS_PER_BLOCK = 512 +MIN_GROUP_SIZE_WARNING = 25 + + +def _nonnegative_error(prefix: str) -> ValueError: + msg = ( + f"{prefix} contains negative values. rank_genes_groups expects " + "nonnegative expression values; use raw counts or log1p/log-normalized " + "expression, not scaled or centered data." + ) + return ValueError(msg) def _check_sparse_nonnegative(X) -> None: - """Reject sparse matrices with explicit negative values. + """Reject inputs with negative values where an eager check is cheap. Sparse rank_genes_groups code treats missing entries as true expression zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add implicit zeros analytically, which is only valid when explicit sparse values are nonnegative expression values. """ + dtype = None + if sp.issparse(X) or cpsp.issparse(X): + dtype = np.dtype(X.data.dtype) + elif isinstance(X, np.ndarray | cp.ndarray): + dtype = np.dtype(X.dtype) + if dtype is not None and dtype.kind == "c": + msg = "rank_genes_groups does not support complex expression values." + raise TypeError(msg) + if sp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. rank_genes_groups " - "expects nonnegative expression values; use raw counts or " - "log1p/log-normalized expression, not scaled or centered data." - ) - raise ValueError(msg) + raise _nonnegative_error("Sparse input") elif cpsp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. rank_genes_groups " - "expects nonnegative expression values; use raw counts or " - "log1p/log-normalized expression, not scaled or centered data." - ) - raise ValueError(msg) + raise _nonnegative_error("Sparse input") + elif isinstance(X, np.ndarray): + if X.size > 0 and float(np.nanmin(X)) < 0: + raise _nonnegative_error("Dense input") + elif isinstance(X, cp.ndarray): + if X.size > 0 and float(cp.nanmin(X)) < 0: + raise _nonnegative_error("Dense input") def _select_groups( @@ -140,20 +155,6 @@ def _round_up_to_warp(n: int) -> int: return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) -def _select_top_n(scores: NDArray, n_top: int) -> NDArray: - """Select indices of top n scores. - - Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. - This is faster than full sorting when k << n. - """ - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) - partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices - - def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index b96cfee6..880da7e0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -12,14 +12,13 @@ from rapids_singlecell._cuda import _wilcoxon_cuda as _wc from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import EPS, _choose_chunk_size, _get_column_block +from ._utils import EPS, MIN_GROUP_SIZE_WARNING, _choose_chunk_size, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray from ._core import _RankGenes -MIN_GROUP_SIZE_WARNING = 25 DEFAULT_WILCOXON_CHUNK_SIZE = 512 OVR_HOST_CSC_SUB_BATCH = 512 OVR_HOST_CSR_SUB_BATCH = 2048 @@ -29,10 +28,11 @@ OVO_DEVICE_SPARSE_SUB_BATCH = 128 OVR_DENSE_SUB_BATCH = 64 OVO_DENSE_TIERED_SUB_BATCH = 256 -DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 +DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 # leave headroom for rank buffers def _maybe_preload_host_dense(rg: _RankGenes) -> None: + """Preload moderate host-dense matrices to avoid repeated chunk transfers.""" X = rg.X if not isinstance(X, np.ndarray) or X.size == 0: return @@ -259,7 +259,20 @@ def _wilcoxon_scores( def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): - is_f64 = X.data.dtype == np.float64 + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float64: + is_f64 = True + data_arr = X.data + elif data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: + is_f64 = False + data_arr = X.data.astype(np.float32, copy=False) + else: + msg = ( + "Wilcoxon sparse input data dtype must be float32, float64, bool, " + f"or integer; got {data_dtype}." + ) + raise TypeError(msg) + is_idx64 = support_idx64 and X.indices.dtype == np.int64 is_i64 = X.indptr.dtype == np.int64 suffix = "" @@ -270,15 +283,33 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool if is_i64: suffix += "_i64" fn = getattr(module, base_name + suffix) - data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) return fn, data_arr, indices_arr def _device_sparse_arrays_i32_f32(X): + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float32 or data_dtype == np.float64: + pass + elif data_dtype.kind in {"b", "i", "u"}: + pass + else: + msg = ( + "Wilcoxon device sparse input data dtype must be float32, float64, " + f"bool, or integer; got {data_dtype}." + ) + raise TypeError(msg) + if X.indptr.dtype != cp.int32: max_indptr = int(cp.asnumpy(X.indptr[-1])) if max_indptr > np.iinfo(np.int32).max: + warnings.warn( + "Wilcoxon device sparse path requires int32 indptr for CUDA " + "kernels; falling back to the bounded dense chunk path because " + f"nnz={max_indptr} exceeds int32.", + RuntimeWarning, + stacklevel=3, + ) return None data = X.data.astype(cp.float32, copy=False) indices = X.indices.astype(cp.int32, copy=False) @@ -620,12 +651,6 @@ def _wilcoxon_vs_rest( return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) - group_matrix = None - if rg._compute_stats_in_chunks: - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev @@ -645,7 +670,7 @@ def _wilcoxon_vs_rest( block, start, stop, - group_matrix=group_matrix, + group_codes_dev=group_codes_gpu, group_sizes_dev=group_sizes_dev, n_cells=n_cells, ) @@ -838,6 +863,8 @@ def _wilcoxon_with_reference( ) else: csr = X + # Host CSR gather scans each row's native index list and tolerates + # unsorted row indices; avoid a full CSR copy just to sort. csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( _wcs, "ovo_streaming_csr_host", csr, support_idx64=True ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index 70d049af..14793834 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -11,6 +11,8 @@ from rapids_singlecell._compat import DaskArray from rapids_singlecell._cuda import _wilcoxon_binned_cuda as _wb +from ._utils import MIN_GROUP_SIZE_WARNING + if TYPE_CHECKING: from numpy.typing import NDArray @@ -159,7 +161,7 @@ def wilcoxon_binned( ): if gi == ireference: continue - if size <= 25 or n_ref <= 25: + if size <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (reference {n_ref}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", @@ -169,7 +171,7 @@ def wilcoxon_binned( else: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size - if size <= 25 or rest <= 25: + if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (rest {rest}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index 24a40721..719fb939 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import anndata as ad import numpy as np import pytest import scanpy as sc @@ -10,6 +11,10 @@ import rapids_singlecell as rsc +def _make_nonnegative(adata): + adata.X = np.abs(adata.X) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,12 +23,15 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) if sparse: - adata_gpu.X = np.abs(adata_gpu.X).astype(np.float32) + adata_gpu.X = adata_gpu.X.astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() + if sparse: + adata_cpu.X = adata_cpu.X.astype(np.float64) rsc.tl.rank_genes_groups( adata_gpu, @@ -53,19 +61,12 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] - rtol = 1e-6 if sparse else 1e-13 - if sparse and field in {"scores", "logfoldchanges"}: - atol = 1e-6 - elif sparse: - atol = 1e-12 - else: - atol = 1e-15 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) np.testing.assert_allclose( - gpu_values, cpu_values, rtol=rtol, atol=atol, equal_nan=True + gpu_values, cpu_values, rtol=1e-13, atol=1e-15, equal_nan=True ) params = gpu_result["params"] @@ -83,6 +84,7 @@ def test_rank_genes_groups_ttest_honors_layer_and_use_raw(reference, method): np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) base.obs["blobs"] = base.obs["blobs"].astype("category") + _make_nonnegative(base) base.layers["signal"] = base.X.copy() ref_adata = base.copy() @@ -131,6 +133,7 @@ def test_rank_genes_groups_ttest_subset_and_bonferroni(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -169,6 +172,7 @@ def test_rank_genes_groups_ttest_with_renamed_categories( np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # First run with original category names rsc.tl.rank_genes_groups(adata, "blobs", method=method, reference=reference_before) @@ -197,6 +201,7 @@ def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) bdata = adata.copy() groups = ["0", "1", "2", "3"] if reference != "rest" else ["0", "2", "3"] @@ -236,6 +241,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() # Run with pts=True @@ -297,8 +303,6 @@ def test_rank_genes_groups_ttest_direct_scipy(): Creates a simple two-group dataset and compares rapids_singlecell t-test directly against scipy.stats.ttest_ind without intermediate statistics. """ - import anndata as ad - np.random.seed(42) n_group1, n_group2, n_genes = 50, 60, 20 @@ -308,6 +312,9 @@ def test_rank_genes_groups_ttest_direct_scipy(): # Combine into AnnData X = np.vstack([X_group1, X_group2]) + X -= X.min() + X_group1 = X[:n_group1] + X_group2 = X[n_group1:] obs = {"group": ["A"] * n_group1 + ["B"] * n_group2} adata = ad.AnnData(X=X, obs=obs) adata.obs["group"] = adata.obs["group"].astype("category") @@ -350,6 +357,7 @@ def test_rank_genes_groups_ttest_matches_scipy(): adata = pbmc68k_reduced() # Convert to float64 for maximum precision in comparison adata.X = adata.X.astype(np.float64) + _make_nonnegative(adata) # Run rapids_singlecell t-test rsc.tl.rank_genes_groups(adata, "bulk_labels", method="t-test", use_raw=False) @@ -412,6 +420,7 @@ def test_rank_genes_groups_ttest_mask_var_array(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Create mask to select only first 5 genes mask = np.array([True] * 5 + [False] * 5) @@ -439,6 +448,7 @@ def test_rank_genes_groups_ttest_mask_var_string(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Add mask column to adata.var adata.var["highly_variable"] = [True] * 6 + [False] * 4 @@ -465,6 +475,7 @@ def test_rank_genes_groups_ttest_mask_var_matches_scanpy(method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=3, n_observations=150) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() mask = np.array([True, False, True, False, True, True, False, True]) @@ -497,6 +508,7 @@ def test_rank_genes_groups_ttest_rankby_abs(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) adata_abs = adata.copy() # Run without rankby_abs @@ -524,6 +536,7 @@ def test_rank_genes_groups_ttest_key_added(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) custom_key = "my_custom_key" diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 6e3dbf89..29871ba0 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -60,13 +60,77 @@ def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) -def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): +@pytest.mark.parametrize( + "method", + ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], +) +@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_dense"]) +def test_rank_genes_groups_dense_negative_values_raise(method, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(ValueError, match="Dense input contains negative values"): + rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + + +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) +def test_rank_genes_groups_complex_values_raise(fmt): + X = np.array( + [ + [1.0 + 0.0j, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.complex64, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(TypeError, match="complex expression values"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +def test_device_sparse_int64_indptr_overflow_warns(): + from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( + _device_sparse_arrays_i32_f32, + ) + + class FakeSparse: + data = cp.asarray([1.0], dtype=cp.float32) + indices = cp.asarray([0], dtype=cp.int32) + indptr = cp.asarray([0, np.iinfo(np.int32).max + 1], dtype=cp.int64) + + with pytest.warns(RuntimeWarning, match="requires int32 indptr"): + assert _device_sparse_arrays_i32_f32(FakeSparse()) is None + + +def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): np.random.seed(42) - adata_lazy = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) - _make_nonnegative(adata_lazy) - adata_lazy.obs["blobs"] = adata_lazy.obs["blobs"].astype("category") - adata_lazy.X = sp.csr_matrix(adata_lazy.X) - adata_cpu = adata_lazy.copy() + adata_rsc = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_rsc) + adata_rsc.obs["blobs"] = adata_rsc.obs["blobs"].astype("category") + adata_rsc.X = sp.csr_matrix(adata_rsc.X) + adata_cpu = adata_rsc.copy() adata_cpu.X = adata_cpu.X.toarray() kw = { @@ -77,22 +141,27 @@ def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): "tie_correct": True, "n_genes": 4, } - rsc.tl.rank_genes_groups(adata_lazy, **kw) + rsc.tl.rank_genes_groups(adata_rsc, **kw) sc.tl.rank_genes_groups(adata_cpu, **kw) - lazy_result = adata_lazy.uns["rank_genes_groups"] - assert lazy_result["names"].dtype.names == ("0", "2") - assert tuple(lazy_result["names"][0]) == tuple( + rsc_result = adata_rsc.uns["rank_genes_groups"] + assert isinstance(rsc_result["names"], np.ndarray) + assert rsc_result["names"].dtype.names == ("0", "2") + assert tuple(rsc_result["names"][0]) == tuple( adata_cpu.uns["rank_genes_groups"]["names"][0] ) np.testing.assert_array_equal( - lazy_result["names"].copy(), - np.asarray(lazy_result["names"]), + rsc_result["names"].copy(), + np.asarray(rsc_result["names"]), ) - lazy_df = sc.get.rank_genes_groups_df(adata_lazy, group=None) + h5ad_path = tmp_path / "rank_genes_groups.h5ad" + adata_rsc.write_h5ad(h5ad_path) + adata_rsc = sc.read_h5ad(h5ad_path) + + rsc_df = sc.get.rank_genes_groups_df(adata_rsc, group=None) scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) - pd.testing.assert_frame_equal(lazy_df, scanpy_df) + pd.testing.assert_frame_equal(rsc_df, scanpy_df) def test_rank_genes_groups_return_format_removed(): @@ -168,6 +237,60 @@ def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) +def test_rank_genes_groups_wilcoxon_dense_edge_cases_match_scipy(): + X = np.array( + [ + [1.0, 5.0, 0.0, 2.0, 1.0], + [2.0, 5.0, 0.0, 2.0, 1.0], + [3.0, 5.0, 1.0, 2.0, 1.0], + [4.0, 5.0, 1.0, 3.0, 2.0], + [5.0, 5.0, 1.0, 3.0, 2.0], + [6.0, 5.0, 2.0, 3.0, 2.0], + [7.0, 5.0, 2.0, 4.0, 3.0], + [8.0, 5.0, 2.0, 4.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "a", "a", "b", "b", "b", "b"]) + adata = sc.AnnData( + X=X, + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=["no_ties", "all_ties", "zero_ties", "mixed", "pairs"]), + ) + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference="b", + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + expected_u = {} + for idx, name in enumerate(adata.var_names): + result = mannwhitneyu( + X[labels == "a", idx], + X[labels == "b", idx], + alternative="two-sided", + method="asymptotic", + use_continuity=True, + ) + expected_u[name] = result.statistic + + np.testing.assert_allclose( + df["scores"].to_numpy(), + np.array([expected_u[name] for name in df["names"]]), + rtol=1e-13, + atol=1e-15, + ) + assert np.isfinite(df["pvals"]).all() + + def test_rank_genes_groups_return_u_values_requires_wilcoxon(): adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) _make_nonnegative(adata) @@ -190,10 +313,10 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars """Test wilcoxon matches scanpy output across configurations.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: - _make_nonnegative(adata_gpu) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -228,12 +351,12 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + rtol = 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + atol = 1e-15 np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] @@ -283,6 +406,7 @@ def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) + _make_nonnegative(base) base.obs["blobs"] = base.obs["blobs"].astype("category") base.layers["signal"] = base.X.copy() @@ -330,6 +454,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): """Test group subsetting and bonferroni correction.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -360,6 +485,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + _make_nonnegative(adata) adata.obs["target"] = pd.Categorical( ["ref"] * 10 + ["valid"] * 10 + ["singleton"], categories=["ref", "valid", "singleton", "empty"], @@ -383,6 +509,7 @@ def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + _make_nonnegative(adata) adata.obs["target"] = pd.Categorical( ["ref"] * 10 + ["singleton"], categories=["ref", "singleton", "empty"], @@ -434,8 +561,8 @@ def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): gpu_result = adata_gpu.uns["rank_genes_groups"] cpu_result = adata_cpu.uns["rank_genes_groups"] for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + rtol = 1e-13 + atol = 1e-15 for group in gpu_result[field].dtype.names: np.testing.assert_allclose( np.asarray(gpu_result[field][group], dtype=float), @@ -547,7 +674,8 @@ def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): "cupy_csc", ], ) -def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): +@pytest.mark.parametrize("pre_load", [False, True]) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) _make_nonnegative(adata_gpu) @@ -563,14 +691,14 @@ def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): "tie_correct": True, "n_genes": 5, } - rsc.tl.rank_genes_groups(adata_gpu, **kw) + rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=pre_load) sc.tl.rank_genes_groups(adata_cpu, **kw) gpu_result = adata_gpu.uns["rank_genes_groups"] cpu_result = adata_cpu.uns["rank_genes_groups"] for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + rtol = 1e-13 + atol = 1e-15 for group in gpu_result[field].dtype.names: np.testing.assert_allclose( np.asarray(gpu_result[field][group], dtype=float), @@ -591,6 +719,7 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( """Test with renamed category labels.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") # First run with original category names @@ -622,6 +751,7 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): """Test that group order doesn't affect results.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") bdata = adata.copy() @@ -661,6 +791,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): """Test that pts (fraction of cells expressing) is computed correctly.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") adata_cpu = adata_gpu.copy() From 73cda5880f6f850486b1455d81af1a224f6ac715 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 13 May 2026 18:03:55 +0200 Subject: [PATCH 07/36] fix tests --- .../tools/_rank_genes_groups/_utils.py | 6 ----- tests/test_rank_genes_groups_wilcoxon.py | 27 ------------------- 2 files changed, 33 deletions(-) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index de91e25d..e9efbc50 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -51,12 +51,6 @@ def _check_sparse_nonnegative(X) -> None: elif cpsp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: raise _nonnegative_error("Sparse input") - elif isinstance(X, np.ndarray): - if X.size > 0 and float(np.nanmin(X)) < 0: - raise _nonnegative_error("Dense input") - elif isinstance(X, cp.ndarray): - if X.size > 0 and float(cp.nanmin(X)) < 0: - raise _nonnegative_error("Dense input") def _select_groups( diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 5bef924b..af39da54 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -60,33 +60,6 @@ def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) -@pytest.mark.parametrize( - "method", - ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], -) -@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_dense"]) -def test_rank_genes_groups_dense_negative_values_raise(method, fmt): - X = np.array( - [ - [-1.0, 0.0, 2.0], - [0.0, 1.0, 0.0], - [2.0, 0.0, 1.0], - [0.0, 3.0, 0.0], - ], - dtype=np.float32, - ) - adata = sc.AnnData( - X=_to_format(X, fmt), - obs=pd.DataFrame( - {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} - ), - var=pd.DataFrame(index=["g0", "g1", "g2"]), - ) - - with pytest.raises(ValueError, match="Dense input contains negative values"): - rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) - - @pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) def test_rank_genes_groups_complex_values_raise(fmt): X = np.array( From 49d43d0884683b89de0136e6c9a614c5a3e06a3d Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 3 Jun 2026 12:53:17 +0200 Subject: [PATCH 08/36] update --- src/rapids_singlecell/_cuda/nb_types.h | 3 - .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 69 --- .../_cuda/wilcoxon/wilcoxon.cu | 112 +---- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 41 +- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 6 +- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 10 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 14 +- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 178 +++---- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 125 ----- .../tools/_rank_genes_groups/_core.py | 48 +- .../tools/_rank_genes_groups/_wilcoxon.py | 447 ++++++++++-------- tests/test_rank_genes_groups_wilcoxon.py | 74 ++- 12 files changed, 433 insertions(+), 694 deletions(-) diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 36b0db7e..f4daa926 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -104,9 +104,6 @@ using gpu_array_contig = nb::ndarray; template using host_array = nb::ndarray>; -template -using host_array_2d = nb::ndarray>; - // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 5af4e964..08c25c4d 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -137,72 +137,3 @@ __global__ void rank_sums_from_sorted_kernel( } } } - -/** - * OVR dense rank core. - * - * sorted_vals and sorter are F-order outputs of sorting each column of the - * current dense block. The kernel directly accumulates rank sums per group, - * avoiding a full ranks matrix and a group one-hot matrix multiply. - */ -__global__ void ovr_rank_dense_kernel(const float* __restrict__ sorted_vals, - const int* __restrict__ sorter, - const int* __restrict__ group_codes, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_rows, - int n_cols, int n_groups, - bool compute_tie_corr) { - int col = blockIdx.x; - if (col >= n_cols) return; - - const float* sv = sorted_vals + (long long)col * n_rows; - const int* si = sorter + (long long)col * n_rows; - - double local_tie = 0.0; - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - float val = sv[i]; - - int lo = 0, hi = i; - while (lo < hi) { - int mid = lo + ((hi - lo) >> 1); - if (sv[mid] < val) - lo = mid + 1; - else - hi = mid; - } - int tie_start = lo; - - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = lo + ((hi - lo + 1) >> 1); - if (sv[mid] > val) - hi = mid - 1; - else - lo = mid; - } - int tie_end = lo; - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - - int row = si[i]; - int group = group_codes[row]; - if (group >= 0 && group < n_groups) { - atomicAdd(&rank_sums[(size_t)group * n_cols + col], avg_rank); - } - - if (compute_tie_corr && i == tie_end) { - double t = (double)(tie_end - tie_start + 1); - if (t > 1.0) local_tie += t * t * t - t; - } - } - - if (!compute_tie_corr) return; - - __shared__ double warp_buf[32]; - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } -} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d314b289..ccca24e7 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -15,19 +15,6 @@ using namespace nb::literals; -static inline void launch_ovr_rank_dense( - const float* sorted_vals, const int* sorter, const int* group_codes, - double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - ovr_rank_dense_kernel<<>>( - sorted_vals, sorter, group_codes, rank_sums, tie_corr, n_rows, n_cols, - n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); -} - static void launch_ovr_rank_dense_streaming( const float* block, const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, @@ -149,11 +136,11 @@ static void launch_ovr_rank_dense_streaming( for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); } -static void launch_ovo_rank_dense_tiered_impl( - const float* ref_data, bool ref_is_sorted, const float* grp_data, - const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, - int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols, cudaStream_t upstream_stream) { +static void launch_ovo_rank_dense_tiered_unsorted_ref( + const float* ref_data, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; @@ -201,7 +188,7 @@ static void launch_ovo_rank_dense_tiered_impl( doff, doff + 1, BEGIN_BIT, END_BIT); } size_t ref_cub_temp_bytes = 0; - if (!ref_is_sorted) { + { auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( @@ -244,15 +231,9 @@ static void launch_ovo_rank_dense_tiered_impl( }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; ++s) { - if (ref_is_sorted) { - bufs[s].ref_sorted = nullptr; - bufs[s].ref_seg_offsets = nullptr; - bufs[s].ref_cub_temp = nullptr; - } else { - bufs[s].ref_sorted = pool.alloc(sub_ref_items); - bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); - } + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); bufs[s].grp_cub_temp = needs_tier3 ? pool.alloc(grp_cub_temp_bytes) : nullptr; bufs[s].ref_tie_sums = @@ -296,15 +277,13 @@ static void launch_ovo_rank_dense_tiered_impl( auto& buf = bufs[s]; const float* ref_sub = ref_data + (size_t)col * n_ref; const float* grp_sub = grp_data + (size_t)col * n_all_grp; - if (!ref_is_sorted) { - upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - size_t temp = ref_cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.ref_cub_temp, temp, ref_sub, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - ref_sub = buf.ref_sorted; - } + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + size_t ref_temp = ref_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.ref_cub_temp, ref_temp, ref_sub, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + ref_sub = buf.ref_sorted; int skip_le = 0; bool run_tier0 = t1.use_tier0; @@ -396,52 +375,10 @@ static void launch_ovo_rank_dense_tiered_impl( for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); } -static void launch_ovo_rank_dense_tiered( - const float* ref_sorted, const float* grp_data, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols, - cudaStream_t upstream_stream) { - launch_ovo_rank_dense_tiered_impl(ref_sorted, true, grp_data, grp_offsets, - rank_sums, tie_corr, n_ref, n_all_grp, - n_cols, n_groups, compute_tie_corr, - sub_batch_cols, upstream_stream); -} - -static void launch_ovo_rank_dense_tiered_unsorted_ref( - const float* ref_data, const float* grp_data, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols, - cudaStream_t upstream_stream) { - launch_ovo_rank_dense_tiered_impl(ref_data, false, grp_data, grp_offsets, - rank_sums, tie_corr, n_ref, n_all_grp, - n_cols, n_groups, compute_tie_corr, - sub_batch_cols, upstream_stream); -} - template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - m.def( - "ovo_rank_dense_tiered", - [](gpu_array_f ref_sorted, - gpu_array_f grp_data, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, - std::uintptr_t stream) { - launch_ovo_rank_dense_tiered(ref_sorted.data(), grp_data.data(), - grp_offsets.data(), rank_sums.data(), - tie_corr.data(), n_ref, n_all_grp, - n_cols, n_groups, compute_tie_corr, - sub_batch_cols, (cudaStream_t)stream); - }, - "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, - "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, - "stream"_a = 0); - m.def( "ovo_rank_dense_tiered_unsorted_ref", [](gpu_array_f ref_data, @@ -462,23 +399,6 @@ void register_bindings(nb::module_& m) { "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); - m.def( - "ovr_rank_dense", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_c group_codes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, std::uintptr_t stream) { - launch_ovr_rank_dense(sorted_vals.data(), sorter.data(), - group_codes.data(), rank_sums.data(), - tie_corr.data(), n_rows, n_cols, n_groups, - compute_tie_corr, (cudaStream_t)stream); - }, - "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, - "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "stream"_a = 0); - m.def( "ovr_rank_dense_streaming", [](gpu_array_f block, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index ec723b55..15afa8a1 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -304,10 +304,10 @@ static inline void upload_linear_offsets(int* d_offsets, int n_segments, // CSR → dense F-order extraction (templated on data type) // ============================================================================ -template +template __global__ void csr_extract_dense_kernel(const T* __restrict__ data, const int* __restrict__ indices, - const int* __restrict__ indptr, + const IndptrT* __restrict__ indptr, const int* __restrict__ row_ids, T* __restrict__ out, int n_target, int col_start, int col_stop) { @@ -315,52 +315,25 @@ __global__ void csr_extract_dense_kernel(const T* __restrict__ data, if (tid >= n_target) return; int row = row_ids[tid]; - int rs = indptr[row]; - int re = indptr[row + 1]; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; - int lo = rs, hi = re; + IndptrT lo = rs, hi = re; while (lo < hi) { - int m = lo + ((hi - lo) >> 1); + IndptrT m = lo + ((hi - lo) >> 1); if (indices[m] < col_start) lo = m + 1; else hi = m; } - for (int p = lo; p < re; ++p) { + for (IndptrT p = lo; p < re; ++p) { int c = indices[p]; if (c >= col_stop) break; out[(long long)(c - col_start) * n_target + tid] = data[p]; } } -template -__global__ void csr_extract_dense_identity_rows_kernel( - const T* __restrict__ data, const int* __restrict__ indices, - const int* __restrict__ indptr, T* __restrict__ out, int n_target, - int col_start, int col_stop) { - int row = blockIdx.x; - if (row >= n_target) return; - - int rs = indptr[row]; - int re = indptr[row + 1]; - - int lo = rs, hi = re; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; - } - - for (int p = lo + threadIdx.x; p < re; p += blockDim.x) { - int c = indices[p]; - if (c >= col_stop) break; - out[(long long)(c - col_start) * n_target + row] = data[p]; - } -} - template __global__ void csr_extract_dense_identity_rows_unsorted_kernel( const T* __restrict__ data, const int* __restrict__ indices, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index b60b87ff..b53ce348 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -8,8 +8,9 @@ * reference slice. This mirrors the fast host-CSR path and avoids redoing the * reference dense extraction + segmented sort for every column sub-batch. */ +template static void ovo_streaming_csr_impl( - const float* csr_data, const int* csr_indices, const int* csr_indptr, + const float* csr_data, const int* csr_indices, const IndptrT* csr_indptr, const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { @@ -301,8 +302,9 @@ static void ovo_streaming_csr_impl( * Like the CSR variant, but extracts rows via lookup maps so it can operate on * native CSC input without converting the whole matrix. */ +template static void ovo_streaming_csc_impl( - const float* csc_data, const int* csc_indices, const int* csc_indptr, + const float* csc_data, const int* csc_indices, const IndptrT* csc_indptr, const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index 9fd626b6..d75e5785 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -28,20 +28,20 @@ __global__ void build_tier3_seg_begin_end_offsets_kernel( * One block per column, threads scatter matching nonzeros. * Output must be pre-zeroed. */ -template +template __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, const IndexT* __restrict__ indices, - const int* __restrict__ indptr, + const IndptrT* __restrict__ indptr, const int* __restrict__ row_map, float* __restrict__ out, int n_target, int col_start) { int col_local = blockIdx.x; int col = col_start + col_local; - int start = indptr[col]; - int end = indptr[col + 1]; + IndptrT start = indptr[col]; + IndptrT end = indptr[col + 1]; - for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + for (IndptrT p = start + threadIdx.x; p < end; p += blockDim.x) { int out_row = row_map[(int)indices[p]]; if (out_row >= 0) { out[(long long)col_local * n_target + out_row] = data[p]; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 257bbbb3..8dd205c8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -530,16 +530,17 @@ static void ovr_sparse_csr_host_streaming_impl( // Sparse-aware CSC OVR streaming (sort only stored nonzeros) // ============================================================================ +template static void ovr_sparse_csc_streaming_impl( - const float* csc_data, const int* csc_indices, const int* csc_indptr, + const float* csc_data, const int* csc_indices, const IndptrT* csc_indptr, const int* group_codes, const double* group_sizes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; // Read indptr to host for batch planning - std::vector h_indptr(n_cols + 1); - cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(IndptrT), cudaMemcpyDeviceToHost); int n_streams = N_STREAMS; @@ -608,8 +609,8 @@ static void ovr_sparse_csc_streaming_impl( auto stream = streams[s]; auto& buf = bufs[s]; - int ptr_start = h_indptr[col]; - int ptr_end = h_indptr[col + sb_cols]; + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), "OVR device CSC active batch nnz"); @@ -689,8 +690,9 @@ static void ovr_sparse_csc_streaming_impl( * * Compared to the dense CSR path, sort work drops by ~1/sparsity. */ +template static void ovr_sparse_csr_streaming_impl( - const float* csr_data, const int* csr_indices, const int* csr_indptr, + const float* csr_data, const int* csr_indices, const IndptrT* csr_indptr, const int* group_codes, const double* group_sizes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 4316d284..5cf7a067 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -19,47 +19,36 @@ template void register_sparse_bindings(nb::module_& m) { m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; - m.def( - "ovr_sparse_csc_device", - [](gpu_array_c csc_data, - gpu_array_c csc_indices, - gpu_array_c csc_indptr, - gpu_array_c group_codes, - gpu_array_c group_sizes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csc_streaming_impl( - csc_data.data(), csc_indices.data(), csc_indptr.data(), - group_codes.data(), group_sizes.data(), rank_sums.data(), - tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, - "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); +#define RSC_OVR_SPARSE_DEVICE_BINDING(NAME, IMPL, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c group_codes, \ + gpu_array_c group_sizes, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ + IMPL(data.data(), indices.data(), indptr.data(), \ + group_codes.data(), group_sizes.data(), rank_sums.data(), \ + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, \ + sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "group_codes"_a, "group_sizes"_a, \ + "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, \ + "n_groups"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) - m.def( - "ovr_sparse_csr_device", - [](gpu_array_c csr_data, - gpu_array_c csr_indices, - gpu_array_c csr_indptr, - gpu_array_c group_codes, - gpu_array_c group_sizes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csr_streaming_impl( - csr_data.data(), csr_indices.data(), csr_indptr.data(), - group_codes.data(), group_sizes.data(), rank_sums.data(), - tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, - "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", + ovr_sparse_csc_streaming_impl, int); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device_i64", + ovr_sparse_csc_streaming_impl, int64_t); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", + ovr_sparse_csr_streaming_impl, int); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device_i64", + ovr_sparse_csr_streaming_impl, int64_t); +#undef RSC_OVR_SPARSE_DEVICE_BINDING #define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ m.def( \ @@ -93,18 +82,10 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64", float, int64_t, - int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64_i64", float, - int64_t, int64_t); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int, int); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64", double, - int64_t, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64_i64", double, - int64_t, int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING #define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -139,65 +120,44 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, int64_t); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, - int); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, - int64_t, int64_t); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, int); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, int64_t); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, - int64_t, int); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, - int64_t, int64_t); #undef RSC_OVR_SPARSE_CSR_HOST_BINDING - m.def( - "ovo_streaming_csc_device", - [](gpu_array_c csc_data, - gpu_array_c csc_indices, - gpu_array_c csc_indptr, - gpu_array_c ref_row_map, - gpu_array_c grp_row_map, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovo_streaming_csc_impl( - csc_data.data(), csc_indices.data(), csc_indptr.data(), - ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols); - }, - "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, - "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_csr_device", - [](gpu_array_c csr_data, - gpu_array_c csr_indices, - gpu_array_c csr_indptr, - gpu_array_c ref_row_ids, - gpu_array_c grp_row_ids, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovo_streaming_csr_impl( - csr_data.data(), csr_indices.data(), csr_indptr.data(), - ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols); - }, - "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, - "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); +#define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c ref_rows, \ + gpu_array_c grp_rows, \ + gpu_array_c grp_offsets, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_ref, int n_all_grp, \ + int n_cols, int n_groups, bool compute_tie_corr, \ + int sub_batch_cols) { \ + IMPL(data.data(), indices.data(), indptr.data(), ref_rows.data(), \ + grp_rows.data(), grp_offsets.data(), rank_sums.data(), \ + tie_corr.data(), n_ref, n_all_grp, n_cols, n_groups, \ + compute_tie_corr, sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "ref_rows"_a, "grp_rows"_a, \ + "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), \ + "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, + int); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device_i64", + ovo_streaming_csc_impl, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, + int); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device_i64", + ovo_streaming_csr_impl, int64_t); +#undef RSC_OVO_DEVICE_BINDING #define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ m.def( \ @@ -235,17 +195,9 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, - int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, - int64_t); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, int64_t); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, - int64_t, int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, - int64_t, int64_t); #undef RSC_OVO_CSC_HOST_BINDING #define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -283,17 +235,9 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, - int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, - int64_t); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, int64_t); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, - int64_t, int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, - int64_t, int64_t); #undef RSC_OVO_CSR_HOST_BINDING } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index efdac894..3fefdc99 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -239,97 +239,6 @@ static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, return 0; } -/** - * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. - * - * Reads a sub-batch block in its native host dtype (InT = float or double), - * writes a float32 copy used as the sort input, and accumulates per-group - * sum, sum-of-squares and nonzero counts in float64. Stats are derived - * from the original-precision values so float64 host input keeps its - * precision while the sort still runs on float32 keys. - * - * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). - * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). - */ -template -__global__ void ovr_cast_and_accumulate_dense_kernel( - const InT* __restrict__ block_in, float* __restrict__ block_f32_out, - const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, - bool compute_nnz = true) { - int col = blockIdx.x; - if (col >= sb_cols) return; - - extern __shared__ double smem[]; - double* s_sum = smem; - double* s_sq = smem + n_groups; - double* s_nnz = smem + 2 * n_groups; - - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - s_sum[g] = 0.0; - if (compute_sq_sums) s_sq[g] = 0.0; - if (compute_nnz) s_nnz[g] = 0.0; - } - __syncthreads(); - - const InT* src = block_in + (size_t)col * n_rows; - float* dst = block_f32_out + (size_t)col * n_rows; - - for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { - InT v_in = src[r]; - double v = (double)v_in; - dst[r] = (float)v_in; - int g = group_codes[r]; - if (g < n_groups) { - atomicAdd(&s_sum[g], v); - if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); - if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); - } - } - __syncthreads(); - - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - group_sums[(size_t)g * sb_cols + col] = s_sum[g]; - if (compute_sq_sums) { - group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; - } - if (compute_nnz) { - group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; - } - } -} - -template -__global__ void ovr_cast_and_accumulate_dense_global_kernel( - const InT* __restrict__ block_in, float* __restrict__ block_f32_out, - const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, - bool compute_nnz = true) { - int col = blockIdx.x; - if (col >= sb_cols) return; - - const InT* src = block_in + (size_t)col * n_rows; - float* dst = block_f32_out + (size_t)col * n_rows; - - for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { - InT v_in = src[r]; - double v = (double)v_in; - dst[r] = (float)v_in; - int g = group_codes[r]; - if (g < n_groups) { - atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); - if (compute_sq_sums) { - atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); - } - if (compute_nnz && v != 0.0) { - atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); - } - } - } -} - /** * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. * @@ -425,40 +334,6 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( } } -template -static void launch_ovr_cast_and_accumulate_dense( - const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, - double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, - int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, - bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, - cudaStream_t stream) { - if (use_gmem) { - size_t stats_items = (size_t)n_groups * sb_cols; - cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); - if (compute_sq_sums) { - cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), - stream); - } - if (compute_nnz) { - cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), - stream); - } - ovr_cast_and_accumulate_dense_global_kernel - <<>>( - d_block_orig, d_block_f32, d_group_codes, d_group_sums, - d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, - compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); - } else { - ovr_cast_and_accumulate_dense_kernel - <<>>( - d_block_orig, d_block_f32, d_group_codes, d_group_sums, - d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, - compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); - } -} - template static void launch_ovr_cast_and_accumulate_sparse( const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 7d5b0665..ba144f4f 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -48,7 +48,7 @@ const int n_groups, const bool compute_nnz ) { - const long long idx = blockIdx.x * blockDim.x + threadIdx.x; + const long long idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; const long long total = static_cast(n_rows) * n_cols; if (idx >= total) { return; @@ -180,7 +180,6 @@ def __init__( self.vars_rest: np.ndarray | None = None self.pts_rest: np.ndarray | None = None - self.stats: pd.DataFrame | None = None self.stats_arrays: dict[str, object] | None = None self._store_wilcoxon_gpu_result = False self._wilcoxon_gpu_result: ( @@ -540,7 +539,6 @@ def compute_statistics( "var_names": np.asarray(self.var_names), "gene_indices": np.empty((0, n_genes_user), dtype=np.intp), } - self.stats = None return if self._wilcoxon_gpu_result is not None: @@ -634,6 +632,22 @@ def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: cp.put_along_axis(corrected, order, corrected_sorted, axis=1) return corrected + def _logfoldchanges_into( + self, arrays: dict, group_indices: np.ndarray, top_idx: np.ndarray + ) -> None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) + def _compute_statistics_arrays( self, test_results: list[tuple[int, NDArray, NDArray]], @@ -673,21 +687,9 @@ def _compute_statistics_arrays( arrays["pvals_adj"] = np.take_along_axis(pvals_adj, top_idx, axis=1) if self.means is not None: - mean_group = self.means[group_indices] - if self.ireference is None: - mean_rest = self.means_rest[group_indices] - else: - mean_rest = self.means[self.ireference][None, :] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS - ) - logfoldchanges = np.log2(foldchanges) - arrays["logfoldchanges"] = np.take_along_axis( - logfoldchanges, top_idx, axis=1 - ).astype(np.float32, copy=False) + self._logfoldchanges_into(arrays, group_indices, top_idx) self.stats_arrays = arrays - self.stats = None def _compute_statistics_gpu_arrays( self, @@ -740,18 +742,6 @@ def _compute_statistics_gpu_arrays( ) ) elif self.means is not None: - mean_group = self.means[group_indices] - if self.ireference is None: - mean_rest = self.means_rest[group_indices] - else: - mean_rest = self.means[self.ireference][None, :] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS - ) - logfoldchanges = np.log2(foldchanges) - arrays["logfoldchanges"] = np.take_along_axis( - logfoldchanges, top_idx, axis=1 - ).astype(np.float32, copy=False) + self._logfoldchanges_into(arrays, group_indices, top_idx) self.stats_arrays = arrays - self.stats = None diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 880da7e0..2cef7665 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -258,7 +258,77 @@ def _wilcoxon_scores( return rank_sums - n_group * (n_group + 1.0) / 2.0 -def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): +def _z_scores_pvals( + rank_sums: cp.ndarray, + expected: cp.ndarray, + variance: cp.ndarray, + sizes: cp.ndarray, + *, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Shared Wilcoxon normal-approximation epilogue -> (scores, p_values).""" + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores(rank_sums, sizes, z, return_u_values=return_u_values) + return scores, p_values + + +def _ovr_z_pvals( + rank_sums: cp.ndarray, + group_sizes_dev: cp.ndarray, + rest_sizes: cp.ndarray, + n_cells: int, + tie_corr: cp.ndarray, + *, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Group-vs-rest scores/p-values (tie_corr is ones when not correcting).""" + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + return _z_scores_pvals( + rank_sums, + expected, + variance, + group_sizes_dev, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + + +def _ovo_z_pvals( + rank_sums: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, + tie_corr_arr: cp.ndarray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Group-vs-reference scores/p-values from rank sums and tie correction.""" + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr + return _z_scores_pvals( + rank_sums, + expected, + variance, + test_sizes, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + + +def _host_sparse_fn_and_arrays(module, base_name: str, X): data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float64: is_f64 = True @@ -273,21 +343,20 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool ) raise TypeError(msg) - is_idx64 = support_idx64 and X.indices.dtype == np.int64 + # Row/column indices always fit int32 (cells and genes are < 2^31); only the + # indptr (cumulative nnz) can need int64. Mirrors the rest of the sparse code. is_i64 = X.indptr.dtype == np.int64 suffix = "" if is_f64: suffix += "_f64" - if is_idx64: - suffix += "_idx64" if is_i64: suffix += "_i64" fn = getattr(module, base_name + suffix) - indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + indices_arr = X.indices.astype(np.int32, copy=False) return fn, data_arr, indices_arr -def _device_sparse_arrays_i32_f32(X): +def _device_sparse_arrays_f32(X): data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float32 or data_dtype == np.float64: pass @@ -300,23 +369,23 @@ def _device_sparse_arrays_i32_f32(X): ) raise TypeError(msg) - if X.indptr.dtype != cp.int32: - max_indptr = int(cp.asnumpy(X.indptr[-1])) - if max_indptr > np.iinfo(np.int32).max: - warnings.warn( - "Wilcoxon device sparse path requires int32 indptr for CUDA " - "kernels; falling back to the bounded dense chunk path because " - f"nnz={max_indptr} exceeds int32.", - RuntimeWarning, - stacklevel=3, - ) - return None data = X.data.astype(cp.float32, copy=False) + # Row/column indices fit int32 (cells and genes are < 2^31); indptr + # (cumulative nnz) may need int64, which the *_i64 device kernels handle. indices = X.indices.astype(cp.int32, copy=False) - indptr = X.indptr.astype(cp.int32, copy=False) + if X.indptr.dtype == cp.int64: + indptr = X.indptr + else: + indptr = X.indptr.astype(cp.int32, copy=False) return data, indices, indptr +def _device_sparse_fn(module, base_name: str, indptr: cp.ndarray): + """Select the device kernel binding, using the int64-indptr variant if needed.""" + suffix = "_i64" if indptr.dtype == cp.int64 else "" + return getattr(module, base_name + suffix) + + def _column_totals_for_host_matrix( X, *, compute_sq_sums: bool, compute_nnz: bool ) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: @@ -503,7 +572,7 @@ def _wilcoxon_vs_rest( csc = csc.copy() csc.sort_indices() csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csc_host", csc, support_idx64=True + _wcs, "ovr_sparse_csc_host", csc ) csc_host_fn( data_arr, @@ -530,7 +599,7 @@ def _wilcoxon_vs_rest( csr = csr.copy() csr.sort_indices() csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csr_host", csr, support_idx64=True + _wcs, "ovr_sparse_csr_host", csr ) csr_host_fn( data_arr, @@ -573,82 +642,74 @@ def _wilcoxon_vs_rest( total_nnz=total_nnz, ) - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / cp.sqrt(variance) - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores_host = _wilcoxon_scores( - rank_sums, group_sizes_dev, z, return_u_values=return_u_values - ).get() + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + scores_host = scores.get() p_host = p_values.get() return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): - sparse_arrays = _device_sparse_arrays_i32_f32(X) - if sparse_arrays is not None: - data, indices, indptr = sparse_arrays - group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) - group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) - rest_sizes = n_cells - group_sizes_dev - rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - tie_corr = cp.ones(n_total_genes, dtype=cp.float64) - if cpsp.isspmatrix_csc(X): - _wcs.ovr_sparse_csc_device( - data, - indices, - indptr, - group_codes_gpu, - group_sizes_dev, - rank_sums, - tie_corr, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, - ) - else: - sparse_X = X - if not sparse_X.has_sorted_indices: - sparse_X = sparse_X.copy() - sparse_X.sort_indices() - data, indices, indptr = _device_sparse_arrays_i32_f32(sparse_X) - _wcs.ovr_sparse_csr_device( - data, - indices, - indptr, - group_codes_gpu, - group_sizes_dev, - rank_sums, - tie_corr, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, - ) - - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = ( - tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + data, indices, indptr = _device_sparse_arrays_f32(X) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + if cpsp.isspmatrix_csc(X): + _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr)( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, + ) + else: + sparse_X = X + if not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + data, indices, indptr = _device_sparse_arrays_f32(sparse_X) + _device_sparse_fn(_wcs, "ovr_sparse_csr_device", indptr)( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, ) - variance *= (n_cells + 1) / 12.0 - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / cp.sqrt(variance) - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores_host = _wilcoxon_scores( - rank_sums, group_sizes_dev, z, return_u_values=return_u_values - ).get() - p_host = p_values.get() - return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + scores_host = scores.get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) @@ -697,18 +758,14 @@ def _wilcoxon_vs_rest( sub_batch_cols=OVR_DENSE_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores = _wilcoxon_scores( - rank_sums, group_sizes_dev, z, return_u_values=return_u_values + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, ) scores_host = scores.get() @@ -835,7 +892,7 @@ def _wilcoxon_with_reference( grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovo_streaming_csc_host", csc, support_idx64=True + _wcs, "ovo_streaming_csc_host", csc ) csc_host_fn( data_arr, @@ -866,7 +923,7 @@ def _wilcoxon_with_reference( # Host CSR gather scans each row's native index list and tolerates # unsorted row indices; avoid a full CSR copy just to sort. csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovo_streaming_csr_host", csr, support_idx64=True + _wcs, "ovo_streaming_csr_host", csr ) csr_host_fn( data_arr, @@ -915,19 +972,14 @@ def _wilcoxon_with_reference( compute_vars=compute_vars, ) - n_combined = test_sizes + n_ref - expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 - variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 - if tie_correct: - variance = variance * tie_corr_arr - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / cp.sqrt(variance) - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores = _wilcoxon_scores( - rank_sums, test_sizes, z, return_u_values=return_u_values + scores, p_values = _ovo_z_pvals( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, ) if rg._store_wilcoxon_gpu_result: rg._wilcoxon_gpu_result = ( @@ -949,80 +1001,73 @@ def _wilcoxon_with_reference( if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: sparse_X = sparse_X.copy() sparse_X.sort_indices() - sparse_arrays = _device_sparse_arrays_i32_f32(sparse_X) - if sparse_arrays is not None: - data, indices, indptr = sparse_arrays - offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) - rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) - tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) - - if cpsp.isspmatrix_csc(sparse_X): - ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) - ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) - grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) - grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) - _wcs.ovo_streaming_csc_device( - data, - indices, - indptr, - cp.asarray(ref_row_map), - cp.asarray(grp_row_map), - offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, - sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, - ) - else: - _wcs.ovo_streaming_csr_device( - data, - indices, - indptr, - cp.asarray(ref_row_ids, dtype=cp.int32), - cp.asarray(all_grp_row_ids, dtype=cp.int32), - offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, - sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, - ) + data, indices, indptr = _device_sparse_arrays_f32(sparse_X) + offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) - n_combined = test_sizes + n_ref - expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 - variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 - if tie_correct: - variance = variance * tie_corr_arr - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / cp.sqrt(variance) - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores = _wilcoxon_scores( - rank_sums, test_sizes, z, return_u_values=return_u_values + if cpsp.isspmatrix_csc(sparse_X): + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + _device_sparse_fn(_wcs, "ovo_streaming_csc_device", indptr)( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + _device_sparse_fn(_wcs, "ovo_streaming_csr_device", indptr)( + data, + indices, + indptr, + cp.asarray(ref_row_ids, dtype=cp.int32), + cp.asarray(all_grp_row_ids, dtype=cp.int32), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, ) - if rg._store_wilcoxon_gpu_result: - rg._wilcoxon_gpu_result = ( - np.asarray(test_group_indices, dtype=np.intp), - scores, - p_values, - None, - ) - return [] - scores_host = scores.get() - p_host = p_values.get() - return [ - (group_index, scores_host[slot], p_host[slot]) - for slot, group_index in enumerate(test_group_indices) - ] + + scores, p_values = _ovo_z_pvals( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + None, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) @@ -1067,20 +1112,14 @@ def _wilcoxon_with_reference( stream=cp.cuda.get_current_stream().ptr, ) - n_combined = test_sizes + n_ref - expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 - variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 - if tie_correct: - variance = variance * tie_corr - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - scores = _wilcoxon_scores( - rank_sums, test_sizes, z, return_u_values=return_u_values + scores, p_values = _ovo_z_pvals( + rank_sums, + test_sizes, + n_ref, + tie_corr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, ) scores_host[:, start:stop] = scores.get() diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index af39da54..7de97082 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -83,9 +83,11 @@ def test_rank_genes_groups_complex_values_raise(fmt): rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) -def test_device_sparse_int64_indptr_overflow_warns(): +def test_device_sparse_int64_indptr_selects_i64_kernel(): + from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _device_sparse_arrays_i32_f32, + _device_sparse_arrays_f32, + _device_sparse_fn, ) class FakeSparse: @@ -93,8 +95,72 @@ class FakeSparse: indices = cp.asarray([0], dtype=cp.int32) indptr = cp.asarray([0, np.iinfo(np.int32).max + 1], dtype=cp.int64) - with pytest.warns(RuntimeWarning, match="requires int32 indptr"): - assert _device_sparse_arrays_i32_f32(FakeSparse()) is None + # int64 indptr is preserved (no truncation, no dense fallback); the row + # indices stay int32 because cells/genes are always < 2^31. + data, indices, indptr = _device_sparse_arrays_f32(FakeSparse()) + assert indptr.dtype == cp.int64 + assert indices.dtype == cp.int32 + assert data.dtype == cp.float32 + + # Dispatch routes an int64 indptr to the *_i64 kernel binding, int32 to base. + assert ( + _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr) + is _wcs.ovr_sparse_csc_device_i64 + ) + assert ( + _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indices) + is _wcs.ovr_sparse_csc_device + ) + + +@pytest.mark.parametrize("layout", ["csc", "csr"]) +def test_device_ovr_sparse_i64_indptr_matches_i32(layout): + # cupyx coerces small matrices to int32 indptr, so int64 support is only + # reachable for nnz > 2^31. Exercise the int64-templated kernels directly + # with a hand-built int64 indptr and assert bit-parity with the int32 path. + from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs + + rng = np.random.default_rng(0) + n_rows, n_cols, n_groups = 120, 10, 4 + dense = np.abs(rng.standard_normal((n_rows, n_cols))).astype(np.float32) + dense[dense < 0.6] = 0.0 + mat = sp.csc_matrix(dense) if layout == "csc" else sp.csr_matrix(dense) + mat.sort_indices() + gcodes = rng.integers(0, n_groups, n_rows).astype(np.int32) + gsizes = np.bincount(gcodes, minlength=n_groups).astype(np.float64) + + data = cp.asarray(mat.data, dtype=cp.float32) + indices = cp.asarray(mat.indices, dtype=cp.int32) + g = cp.asarray(gcodes) + gs = cp.asarray(gsizes) + base = getattr(_wcs, f"ovr_sparse_{layout}_device") + i64 = getattr(_wcs, f"ovr_sparse_{layout}_device_i64") + + def run(indptr_dtype, fn): + indptr = cp.asarray(mat.indptr, dtype=indptr_dtype) + rs = cp.empty((n_groups, n_cols), dtype=cp.float64) + tc = cp.ones(n_cols, dtype=cp.float64) + fn( + data, + indices, + indptr, + g, + gs, + rs, + tc, + n_rows=n_rows, + n_cols=n_cols, + n_groups=n_groups, + compute_tie_corr=True, + sub_batch_cols=64, + ) + cp.cuda.get_current_stream().synchronize() + return rs.get(), tc.get() + + rs32, tc32 = run(cp.int32, base) + rs64, tc64 = run(cp.int64, i64) + np.testing.assert_array_equal(rs32, rs64) + np.testing.assert_array_equal(tc32, tc64) def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): From 4e4b55dca75174c9ee9ad92307f112f11de3dac5 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 3 Jun 2026 16:45:29 +0200 Subject: [PATCH 09/36] safety commit --- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 5 + .../_cuda/wilcoxon/wilcoxon.cu | 27 +--- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 57 +++++++- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 32 ++--- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 35 ++--- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 21 +++ .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 25 +++- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 37 ++--- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 23 ++++ tests/test_rank_genes_groups_wilcoxon.py | 129 ++++++++++++++++++ 10 files changed, 291 insertions(+), 100 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 08c25c4d..80e78de4 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -30,6 +30,11 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, * SortPairs. One block owns one column, walks tie runs, and accumulates the * average ranks per group without materializing a full rank matrix. */ +// Dense OVR rank kernel. One block per column; walks sorted tie runs and +// accumulates average ranks per group without materializing a rank matrix. +// The `use_gmem` flag (set by ovr_smem_config) selects shared- vs +// global-memory group accumulators -- CRITICAL: the use_gmem path is REQUIRED +// when n_groups is large (does NOT fit in smem) and must not be removed. __global__ void rank_sums_from_sorted_kernel( const float* __restrict__ sorted_vals, const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index ccca24e7..f0351d22 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -30,14 +30,8 @@ static void launch_ovr_rank_dense_streaming( size_t sub_items = (size_t)n_rows * sub_batch_cols; int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, sub_items_i32, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); std::vector streams(n_streams); for (int i = 0; i < n_streams; ++i) { @@ -181,20 +175,11 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "Dense OVO group segment count"); - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - } - size_t ref_cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_temp_bytes, fk, fk, sub_ref_items_i32, - sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + grp_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); } + size_t ref_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); std::vector streams(n_streams); for (int i = 0; i < n_streams; ++i) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 15afa8a1..9de92d62 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -8,6 +8,8 @@ #include +#include + #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR void* wilcoxon_rmm_allocate(size_t bytes); @@ -49,13 +51,60 @@ constexpr int TIER1_GROUP_THRESHOLD = 2500; // 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; +// Query CUB device-segmented-radix-sort scratch size with a dummy launch. +// Every Wilcoxon sort uses float keys and (for SortPairs) int values/offsets. +static inline size_t cub_segmented_sortkeys_temp_bytes(int num_items, + int num_segments) { + size_t bytes = 0; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, fk, fk, num_items, + num_segments, doff, doff + 1, + BEGIN_BIT, END_BIT); + return bytes; +} + +template +static inline size_t cub_segmented_sortpairs_temp_bytes(int num_items, + int num_segments) { + size_t bytes = 0; + auto* fk = reinterpret_cast(1); + auto* v = reinterpret_cast(1); + auto* off = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, bytes, fk, fk, v, v, + num_items, num_segments, off, + off + 1, BEGIN_BIT, END_BIT); + return bytes; +} + +// Universal CUDA static per-block shared-memory floor; safe fallback if the +// device query fails. +constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; + +// CRITICAL device-limit query that powers every smem/gmem and tier decision. +// Returns the per-block shared-memory limit (cached per device). Consumed by +// ovr_smem_config, sparse_ovr_smem_config, cast_accumulate_smem_config, and +// make_tier1_config to decide when accumulators/sorts no longer fit in smem and +// must fall back to global memory or CUB. DO NOT hardcode a smem value in place +// of this call -- the gmem-fallback thresholds (e.g. sparse OVR ~3056 groups) +// auto-scale with the GPU because of it; falls back to 48 KB if the query +// fails. static inline size_t wilcoxon_max_smem_per_block() { int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return WILCOXON_FALLBACK_SMEM_PER_BLOCK; + } + static thread_local int cached_dev = -1; + static thread_local size_t cached_smem = 0; + if (device == cached_dev) return cached_smem; int max_smem = 0; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, - device); - return (size_t)max_smem; + if (cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + device) != cudaSuccess) { + return WILCOXON_FALLBACK_SMEM_PER_BLOCK; + } + cached_dev = device; + cached_smem = (size_t)max_smem; + return cached_smem; } static inline int checked_cub_items(size_t count, const char* context) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index b53ce348..1a9215a4 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -70,11 +70,8 @@ static void ovo_streaming_csr_impl( int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "OVO device CSR group segment count"); - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); cub_temp_bytes = cub_grp_bytes; } @@ -158,12 +155,8 @@ static void ovo_streaming_csr_impl( upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, ref_stream); - size_t ref_cub_bytes = 0; - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, cache_ref_items_i32, cache_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); + size_t ref_cub_bytes = + cub_segmented_sortkeys_temp_bytes(cache_ref_items_i32, cache_cols); ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); size_t ref_temp = ref_cub_bytes; cub::DeviceSegmentedRadixSort::SortKeys( @@ -339,25 +332,16 @@ static void ovo_streaming_csc_impl( int sub_grp_items_i32 = checked_cub_items(sub_grp_items, "OVO device CSC group sub-batch"); - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } + size_t cub_ref_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { size_t cub_grp_bytes = 0; int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "OVO device CSC group segment count"); - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 53f27bbe..fc91821b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -46,25 +46,15 @@ static void ovo_streaming_csc_host_impl( checked_cub_items(sub_grp_items, "OVO host CSC group sub-batch"); // CUB temp - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } + size_t cub_ref_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { - size_t cub_grp_bytes = 0; int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "OVO host CSC group segment count"); - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); + size_t cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } @@ -629,14 +619,8 @@ static void ovo_streaming_csr_host_impl( } // Segmented sort ref_dense by column → ref_sorted - size_t ref_cub_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, ref_items_i32, n_cols, doff, - doff + 1, BEGIN_BIT, END_BIT); - } + size_t ref_cub_bytes = + cub_segmented_sortkeys_temp_bytes(ref_items_i32, n_cols); ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); size_t temp = ref_cub_bytes; @@ -662,14 +646,11 @@ static void ovo_streaming_csr_host_impl( if (may_need_cub && max_sub_items > 0) { int max_sub_items_i32 = checked_cub_items(max_sub_items, "OVO host CSR group pack"); - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); int max_segments = checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, "OVO host CSR max group segment count"); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, max_sub_items_i32, max_segments, - doff, doff + 1, BEGIN_BIT, END_BIT); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(max_sub_items_i32, max_segments); } std::vector streams(n_streams); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index d75e5785..c93ceccf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -74,6 +74,19 @@ struct Tier1Config { size_t tier1_smem = 0; }; +// SINGLE source of truth for OVO tier dispatch, used by EVERY OVO path: +// dense (wilcoxon.cu) AND all four sparse OVO impls (host/device CSC/CSR), +// which extract ref+group rows to dense per sub-batch and then call this. +// Scans per-group sizes once and returns which size-gated tiers to co-launch: +// Tier 0 (<=32, TIER0_GROUP_THRESHOLD): ovo_warp_sort_rank_kernel +// Tier 0.5 (<=64, TIER0_64_GROUP_THRESHOLD): ovo_small64_sort_rank_kernel +// Tier 2 (<=512, TIER2_GROUP_THRESHOLD): ovo_medium_unsorted_rank_kernel +// Tier 1 (<=2500, TIER1_GROUP_THRESHOLD): ovo_fused_sort_rank_kernel (smem +// sort) Tier 3 (>2500): CUB segmented sort + +// batched_rank_sums_presorted_kernel +// Tiers cooperate via skip_n_grp_le: a larger tier skips groups a smaller tier +// already handled. Tier 1 is device-adapted: if its fused-sort smem footprint +// would exceed wilcoxon_max_smem_per_block() it is disabled in favor of Tier 3. static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { Tier1Config c; c.min_grp_size = INT_MAX; @@ -104,6 +117,14 @@ static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + WARP_REDUCE_BUF * sizeof(double); + // Adapt to the device: if the fused-sort buffer would exceed the + // per-block shared-memory limit, fall back to the tier-3 CUB segmented + // sort (which has no smem cap) rather than launching a kernel that + // would fail. Never triggers at the current threshold (~16.6KB), but + // keeps the dispatch correct if the threshold or device limit changes. + if (c.tier1_smem > wilcoxon_max_smem_per_block()) { + c.use_tier1 = false; + } } return c; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 2323e27f..2b282b0b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -49,6 +49,16 @@ __global__ void csr_scatter_to_csc_kernel( } } +// CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). +// +// Decide smem-vs-gmem for the DENSE OVR rank kernel +// (rank_sums_from_sorted_kernel). Per-block accumulator is one double per group +// plus a 32-slot warp buffer, i.e. (n_groups + 32) doubles. When that exceeds +// the per-block smem limit (~48 KB) the kernel must fall back to a +// global-memory accumulator (use_gmem=true). With a 48 KB limit this flips at +// roughly n_groups > 6112. Not dead: a kernel launched in smem mode with an +// oversized request simply fails to launch. Limit is device-queried via +// wilcoxon_max_smem_per_block(), so it auto-scales. static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { @@ -61,8 +71,19 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { } /** - * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator - * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + * CRITICAL — DO NOT REMOVE the gmem branch. This is the load-bearing path for + * Perturb-seq / pooled-CRISPR DE, where n_groups is in the thousands. + * + * Decide smem-vs-gmem for the sparse OVR rank kernel. The per-block accumulator + * is two double arrays of size n_groups (grp_sums + grp_nz_count) plus a + * 32-slot warp buffer, i.e. (2*n_groups + 32) doubles. When that exceeds the + * per-block shared-memory limit (~48 KB) the kernel CANNOT launch in smem mode, + * so we set use_gmem=true and rank_sums_sparse_ovr_kernel accumulates in a + * caller-provided global-memory buffer instead. With a 48 KB limit this flips + * at roughly n_groups > 3056. Reviewers/static analysis have twice mistaken + * this fallback for dead code; it is the ONLY path that works at large + * n_groups. The limit is queried per device via wilcoxon_max_smem_per_block(), + * so the threshold auto-scales with the GPU. */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 8dd205c8..6b3b8dbf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -34,11 +34,8 @@ static void ovr_sparse_csc_host_streaming_impl( if (max_nnz > 0) { int max_nnz_i32 = checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_segmented_sortpairs_temp_bytes( + max_nnz_i32, sub_batch_cols); } std::vector streams(n_streams); @@ -305,11 +302,8 @@ static void ovr_sparse_csr_host_streaming_impl( if (max_batch_nnz > 0) { int max_batch_nnz_i32 = checked_cub_items( max_batch_nnz, "OVR host CSR sparse sub-batch nnz"); - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(max_batch_nnz_i32, + sub_batch_cols); } int tpb = UTIL_BLOCK_SIZE; @@ -560,11 +554,8 @@ static void ovr_sparse_csc_streaming_impl( if (max_nnz > 0) { int max_nnz_i32 = checked_cub_items(max_nnz, "OVR device CSC sparse sub-batch nnz"); - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); } std::vector streams(n_streams); @@ -744,11 +735,8 @@ static void ovr_sparse_csr_streaming_impl( if (max_batch_nnz > 0) { int max_batch_nnz_i32 = checked_cub_items( max_batch_nnz, "OVR device CSR sparse sub-batch nnz"); - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(max_batch_nnz_i32, + sub_batch_cols); } int n_streams = N_STREAMS; @@ -756,11 +744,18 @@ static void ovr_sparse_csr_streaming_impl( // CSR path needs 4 sort arrays per stream (scatter intermediates + // CUB output). Fit stream count to available GPU memory. + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); size_t per_stream_bytes = max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + (size_t)n_groups * sub_batch_cols * sizeof(double) + sub_batch_cols * sizeof(double); + if (rank_use_gmem) { + // gmem rank fallback (n_groups too large for smem): per-stream + // d_nz_scratch accumulator, same size as sub_rank_sums. + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } size_t free_mem = 0, total_mem = 0; cudaMemGetInfo(&free_mem, &total_mem); @@ -773,8 +768,6 @@ static void ovr_sparse_csr_streaming_impl( for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); int tpb = UTIL_BLOCK_SIZE; - bool rank_use_gmem = false; - size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); int scatter_blocks = (n_rows + tpb - 1) / tpb; struct StreamBuf { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 3fefdc99..54fc42d4 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -31,6 +31,15 @@ * * Grid: (sb_cols,) Block: (tpb,) */ +// HEADLINE sparse-OVR optimization (OVR-only). Ranks ONLY stored positive +// values; all zeros (stored + implicit n_rows-nnz) are treated as one leading +// tie block ranked analytically at (total_zero+1)/2, and each group's zero +// contribution is applied in closed form. Cost is O(nnz log nnz) per column, +// not O(n_rows log n_rows). The `use_gmem` flag selects shared- vs +// global-memory accumulators (see sparse_ovr_smem_config) -- CRITICAL: the +// use_gmem path is REQUIRED for large n_groups (Perturb-seq) and must not be +// removed. Validity relies on the upstream rejection of explicit negative +// sparse values, which guarantees zeros form the first tie block. template __global__ void rank_sums_sparse_ovr_kernel( const float* __restrict__ sorted_vals, @@ -227,6 +236,15 @@ __global__ void rank_sums_sparse_ovr_kernel( } } +// CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). +// +// Decide smem-vs-gmem for the sparse-OVR stats cast-and-accumulate kernel +// (sums / sq-sums / nnz). Needs n_arrays*n_groups doubles in smem; when that +// exceeds the per-block limit, use_gmem=true selects +// ovr_cast_and_accumulate_sparse_global_kernel, which accumulates directly in +// global memory. Same large-n_groups workloads that drive +// sparse_ovr_smem_config to gmem also drive this one; both fallbacks are +// load-bearing, not dead. static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, bool compute_nnz, bool& use_gmem) { int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); @@ -302,6 +320,11 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } } +// CRITICAL — DO NOT REMOVE. Global-memory variant of the stats accumulator, +// selected by cast_accumulate_smem_config when n_groups is too large for the +// smem version. Required for Perturb-seq-scale n_groups; the smem kernel cannot +// launch when its (n_arrays*n_groups) double buffer exceeds the per-block +// limit. template __global__ void ovr_cast_and_accumulate_sparse_global_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 7de97082..b1cb3298 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -748,6 +748,135 @@ def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): ) +def _make_sized_groups_adata(group_sizes, n_genes, seed=0): + """AnnData with exact per-group sizes (drives OVO tier selection by max size).""" + rng = np.random.default_rng(seed) + n_obs = int(sum(group_sizes)) + X = np.abs(rng.standard_normal((n_obs, n_genes))).astype(np.float32) + X[X < 0.3] = 0.0 # zeros create tie groups, exercising tie correction + labels = np.concatenate( + [np.full(sz, f"g{i}", dtype=object) for i, sz in enumerate(group_sizes)] + ) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"gene_{j}" for j in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + adata.uns["log1p"] = {"base": None} + return adata + + +# Tier thresholds (wilcoxon_fast_common.cuh): tier0<=32, tier0_64<=64, +# tier2<=512, tier1(fused smem sort)<=2500, tier3(CUB segmented sort)>2500. +# Group sizes in the standard blobs datasets are <=~70, so tier1/tier3 are +# otherwise never exercised. These force a single large test group. +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +@pytest.mark.parametrize("tie_correct", [False, True]) +@pytest.mark.parametrize("big", [700, 3000], ids=["tier1_fused", "tier3_cub"]) +def test_wilcoxon_ovo_large_group_tiers_match_scanpy(fmt, tie_correct, big): + # g0 = reference, g1 = the large test group that drives tier selection. + adata_gpu = _make_sized_groups_adata([60, big, 45], n_genes=6, seed=1) + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": "g0", + "tie_correct": tie_correct, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu = adata_gpu.uns["rank_genes_groups"] + cpu = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in gpu[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu[field][group], dtype=float), + np.asarray(cpu[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +def test_wilcoxon_ovo_mixed_tier_sizes_match_scanpy(fmt): + # Groups spanning tier0 (20), tier0_64 (50) and tier2 (300) co-launched with + # tie_correct=True, pinning the skip_le boundaries and the ref_tie_sums gate. + adata_gpu = _make_sized_groups_adata([80, 20, 50, 300], n_genes=6, seed=2) + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": "g0", + "tie_correct": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu = adata_gpu.uns["rank_genes_groups"] + cpu = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in gpu[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu[field][group], dtype=float), + np.asarray(cpu[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +# n_groups > ~3056 makes the per-block smem for the sparse-OVR accumulator +# ((2*n_groups+32) doubles) exceed the 48KB static limit, so sparse_ovr_smem_config +# (and the dense ovr_smem_config) fall back to the global-memory accumulator. +# This is the perturbation regime (thousands of guides vs rest). scanpy's +# 3000+-group DataFrame build is O(n_groups^2) and too slow for an in-suite +# parity check; gmem-vs-scanpy parity is verified out-of-band (<=2e-15). Here we +# guard that every storage format (incl. the dense reference kernel) agrees at +# gmem scale, with and without tie correction. +@pytest.mark.parametrize("tie_correct", [False, True]) +def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): + adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=3) + ref = None + for fmt in ("numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=tie_correct, + n_genes=6, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + if ref is None: + ref = cur + continue + for field in ("scores", "pvals"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + @pytest.mark.parametrize( ("groups", "reference"), [ From 75b810a23637de3750fef7a62589a2c08f8fb370 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 3 Jun 2026 17:45:10 +0200 Subject: [PATCH 10/36] start cleanup --- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 78 +++---- .../_cuda/wilcoxon/wilcoxon.cu | 96 ++------ .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 18 +- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 206 +++-------------- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 196 +++------------- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 214 ++++++++++++------ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 71 ++---- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 27 +++ 8 files changed, 335 insertions(+), 571 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index a8e9ed4f..bfbc0dc2 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -155,7 +155,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( // sorted grp_col values within each thread's stride. // ============================================================================ -__global__ void batched_rank_sums_presorted_kernel( +__global__ void ovo_rank_huge_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, @@ -168,7 +168,7 @@ __global__ void batched_rank_sums_presorted_kernel( int g_end = grp_offsets[grp + 1]; int n_grp = g_end - g_start; - // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + // Size-gated dispatch (see ovo_rank_large_kernel for the contract). if (n_grp <= skip_n_grp_le) return; if (n_grp == 0) { @@ -261,11 +261,11 @@ __global__ void batched_rank_sums_presorted_kernel( // ============================================================================ // Tier 1 fused kernel: smem bitonic sort + binary search rank sums // For small groups (< ~2K cells). No CUB, no global memory sort buffers. -// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) -// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) +// Grid: (n_cols, n_groups), Block: min(large_padded, 512) +// Shared memory: large_padded floats + 32 doubles (warp reduction) // ============================================================================ -__global__ void ovo_fused_sort_rank_kernel( +__global__ void ovo_rank_large_kernel( const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) // unsorted @@ -273,7 +273,7 @@ __global__ void ovo_fused_sort_rank_kernel( double* __restrict__ rank_sums, // (n_groups, n_cols) row-major double* __restrict__ tie_corr, // (n_groups, n_cols) row-major int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, - int padded_grp_size, int skip_n_grp_le /*= 0*/) { + int large_padded, int skip_n_grp_le /*= 0*/) { int col = blockIdx.x; int grp = blockIdx.y; if (col >= n_cols || grp >= n_groups) return; @@ -295,23 +295,23 @@ __global__ void ovo_fused_sort_rank_kernel( return; } - // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + // Shared memory: [large_padded floats | 32 doubles for warp reduction] extern __shared__ char smem_raw[]; float* grp_smem = (float*)smem_raw; - double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + double* warp_buf = (double*)(smem_raw + large_padded * sizeof(float)); // Load group data into shared memory, pad with +INF const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; for (int i = threadIdx.x; i < n_grp; i += blockDim.x) grp_smem[i] = grp_col[i]; - for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + for (int i = n_grp + threadIdx.x; i < large_padded; i += blockDim.x) grp_smem[i] = __int_as_float(0x7f800000); // +INF __syncthreads(); // Bitonic sort in shared memory - for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int k = 2; k <= large_padded; k <<= 1) { for (int j = k >> 1; j > 0; j >>= 1) { - for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + for (int i = threadIdx.x; i < large_padded; i += blockDim.x) { int ixj = i ^ j; if (ixj > i) { bool asc = ((i & k) == 0); @@ -439,7 +439,7 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, if (threadIdx.x == 0) ref_tie_sums[col] = total; } -__global__ void ovo_small64_sort_rank_kernel( +__global__ void ovo_rank_small_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, @@ -452,24 +452,24 @@ __global__ void ovo_small64_sort_rank_kernel( int g_start = grp_offsets[grp]; int g_end = grp_offsets[grp + 1]; int n_grp = g_end - g_start; - if (n_grp <= skip_n_grp_le || n_grp > TIER0_64_GROUP_THRESHOLD) return; + if (n_grp <= skip_n_grp_le || n_grp > OVO_SMALL_MAX) return; - __shared__ float grp_smem[TIER0_64_GROUP_THRESHOLD]; + __shared__ float grp_smem[OVO_SMALL_MAX]; __shared__ double warp_buf[WARP_REDUCE_BUF]; const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; const float POS_INF = __int_as_float(0x7f800000); - if (threadIdx.x < TIER0_64_GROUP_THRESHOLD) { + if (threadIdx.x < OVO_SMALL_MAX) { grp_smem[threadIdx.x] = (threadIdx.x < n_grp) ? grp_col[threadIdx.x] : POS_INF; } __syncthreads(); - for (int k = 2; k <= TIER0_64_GROUP_THRESHOLD; k <<= 1) { + for (int k = 2; k <= OVO_SMALL_MAX; k <<= 1) { for (int j = k >> 1; j > 0; j >>= 1) { int i = threadIdx.x; int ixj = i ^ j; - if (i < TIER0_64_GROUP_THRESHOLD && ixj > i) { + if (i < OVO_SMALL_MAX && ixj > i) { bool asc = ((i & k) == 0); float a = grp_smem[i], b = grp_smem[ixj]; if (asc ? (a > b) : (a < b)) { @@ -569,7 +569,7 @@ __global__ void ovo_small64_sort_rank_kernel( // ref_tie_sums[col] and adds only group-only / ref-overlap deltas. // ============================================================================ -__global__ void ovo_medium_unsorted_rank_kernel( +__global__ void ovo_rank_medium_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, @@ -675,10 +675,9 @@ __global__ void ovo_medium_unsorted_rank_kernel( // sync is __syncwarp — no smem, no __syncthreads. // ============================================================================ -__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, - int n_ref, float v_lane, - int n_grp, - unsigned int active_mask) { +__device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, + float v_lane, int n_grp, + unsigned int active_mask) { int lane = threadIdx.x & 31; double local_tie = 0.0; @@ -708,7 +707,7 @@ __device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, // the result. int cnt_grp = 0; #pragma unroll - for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { float vi = __shfl_sync(0xffffffff, v_lane, lane_i); if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; } @@ -752,7 +751,7 @@ __device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, // group values consume the count. int cnt = 0; #pragma unroll - for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { int src_lane = (lane_i < n_grp) ? lane_i : 0; float vi = __shfl_sync(active_mask, v_lane, src_lane); if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && @@ -773,9 +772,10 @@ __device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, return local_tie; // meaningful on lane 0. } -__device__ __forceinline__ double tier0_tie_delta_warp( - const float* ref_col, int n_ref, float v_lane, int n_grp, - unsigned int active_mask) { +__device__ __forceinline__ double warp_tie_delta(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { int lane = threadIdx.x & 31; double local_delta = 0.0; @@ -789,7 +789,7 @@ __device__ __forceinline__ double tier0_tie_delta_warp( int cnt_grp = 0; #pragma unroll - for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { int src_lane = (lane_i < n_grp) ? lane_i : 0; float vi = __shfl_sync(active_mask, v_lane, src_lane); if (lane_i < n_grp && vi == v) ++cnt_grp; @@ -850,12 +850,14 @@ __device__ __forceinline__ double tier0_tie_delta_warp( // Grid: (n_cols, ceil(n_groups / 8)), Block: 256. // ============================================================================ -__global__ void ovo_warp_sort_rank_kernel( - const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, - const int* __restrict__ grp_offsets, - const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr) { +__global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, + bool compute_tie_corr) { constexpr int WARPS_PER_BLOCK = 8; int warp_id = threadIdx.x >> 5; int lane = threadIdx.x & 31; @@ -872,7 +874,7 @@ __global__ void ovo_warp_sort_rank_kernel( // per lane). Larger groups are delegated to Tier 1/3 in a co-launched // kernel; since each group owns its own row in rank_sums/tie_corr, the // two kernels interlace into the output without conflict. - if (n_grp > TIER0_GROUP_THRESHOLD) return; + if (n_grp > OVO_WARP_MAX) return; if (n_grp == 0) { if (lane == 0) { @@ -932,7 +934,7 @@ __global__ void ovo_warp_sort_rank_kernel( int n_eq_grp_offset = 0; // tied lanes strictly before this one int n_eq_grp_after = 1; // count self #pragma unroll - for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { if (lane_i >= n_grp) continue; float vi = __shfl_sync(active_mask, v, lane_i); if (lane_i < lane) { @@ -964,9 +966,9 @@ __global__ void ovo_warp_sort_rank_kernel( double tie_sum; if (ref_tie_sums != nullptr) { tie_sum = ref_tie_sums[col] + - tier0_tie_delta_warp(ref_col, n_ref, x, n_grp, active_mask); + warp_tie_delta(ref_col, n_ref, x, n_grp, active_mask); } else { - tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + tie_sum = warp_tie_sum(ref_col, n_ref, x, n_grp, active_mask); } if (lane == 0) { int n = n_ref + n_grp; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index f0351d22..890242c2 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -142,19 +142,16 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( cudaStreamSynchronize(upstream_stream); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); - auto t1 = make_tier1_config(h_offsets.data(), n_groups); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.any_above_t2 && t1.use_tier1; - bool needs_tier3 = t1.any_above_t2 && !use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; std::vector h_sort_group_ids; int n_sort_groups = n_groups; - if (needs_tier3) { - h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, - TIER2_GROUP_THRESHOLD); + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); n_sort_groups = (int)h_sort_group_ids.size(); } @@ -171,7 +168,7 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( checked_cub_items(sub_grp_items, "Dense OVO group sub-batch"); size_t grp_cub_temp_bytes = 0; - if (needs_tier3) { + if (run_huge) { int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "Dense OVO group segment count"); @@ -195,7 +192,7 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( RmmScratchPool pool; int* d_sort_group_ids = nullptr; - if (needs_tier3) { + if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), @@ -220,17 +217,16 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); bufs[s].grp_cub_temp = - needs_tier3 ? pool.alloc(grp_cub_temp_bytes) : nullptr; + run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; bufs[s].ref_tie_sums = - (compute_tie_corr && - (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = pool.alloc((size_t)n_groups * sub_batch_cols); - if (needs_tier3) { + if (run_huge) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, @@ -270,68 +266,14 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); ref_sub = buf.ref_sorted; - int skip_le = 0; - bool run_tier0 = t1.use_tier0; - bool run_tier0_64 = t1.any_tier0_64; - bool run_tier2 = t1.any_tier2; - if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { - launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, - stream); - } - if (run_tier0) { - launch_tier0(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, stream); - if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; - } - if (run_tier0_64) { - launch_tier0_64(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, - n_all_grp, sb_cols, n_groups, compute_tie_corr, - skip_le, stream); - if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { - skip_le = TIER0_64_GROUP_THRESHOLD; - } - } - if (run_tier2) { - launch_tier2_medium(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, - n_all_grp, sb_cols, n_groups, compute_tie_corr, - skip_le, stream); - } - - int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; - if (t1.any_above_t2 && use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - ref_sub, grp_sub, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else if (needs_tier3) { - int sb_grp_seg = - checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, - "Dense OVO active group segment count"); - int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_tier3_seg_begin_end_offsets_kernel<<>>( - grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, - buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); - CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); - - size_t temp = grp_cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.grp_cub_temp, temp, grp_sub, buf.grp_sorted, - sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); - - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.grp_cub_temp}; + ovo_dispatch_tiers(ref_sub, grp_sub, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 9de92d62..d1f32335 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -33,18 +33,18 @@ constexpr int WARP_REDUCE_BUF = 32; // amortised 8× across (col, group) work items. This path is the fast // route for per-celltype perturbation-style workloads where most test // groups have only a few dozen cells. -constexpr int TIER0_GROUP_THRESHOLD = 32; +constexpr int OVO_WARP_MAX = 32; // Second small-group tier for perturbation workloads where most groups are // slightly larger than one warp. Uses one compact shared-memory sort block per // (column, group), avoiding the heavier Tier 2 in-group scan. -constexpr int TIER0_64_GROUP_THRESHOLD = 64; +constexpr int OVO_SMALL_MAX = 64; // Medium-group cutoff for the unsorted direct-rank kernel. For perturbation // workloads most groups sit below this range, where avoiding a full smem // bitonic sort wins despite the O(n^2) in-group count. -constexpr int TIER2_GROUP_THRESHOLD = 512; +constexpr int OVO_MEDIUM_MAX = 512; // Max group size for the fused smem-sort rank kernel (Tier 1 fast path). // Beyond this, fall back to CUB segmented sort + binary-search rank kernel. -constexpr int TIER1_GROUP_THRESHOLD = 2500; +constexpr int OVO_LARGE_MAX = 2500; // Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes // each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = // fewer kernel launches; smaller = less per-stream memory. 128M items × 4B = @@ -84,11 +84,11 @@ constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; // CRITICAL device-limit query that powers every smem/gmem and tier decision. // Returns the per-block shared-memory limit (cached per device). Consumed by // ovr_smem_config, sparse_ovr_smem_config, cast_accumulate_smem_config, and -// make_tier1_config to decide when accumulators/sorts no longer fit in smem and -// must fall back to global memory or CUB. DO NOT hardcode a smem value in place -// of this call -- the gmem-fallback thresholds (e.g. sparse OVR ~3056 groups) -// auto-scale with the GPU because of it; falls back to 48 KB if the query -// fails. +// make_ovo_tier_plan to decide when accumulators/sorts no longer fit in smem +// and must fall back to global memory or CUB. DO NOT hardcode a smem value in +// place of this call -- the gmem-fallback thresholds (e.g. sparse OVR ~3056 +// groups) auto-scale with the GPU because of it; falls back to 48 KB if the +// query fails. static inline size_t wilcoxon_max_smem_per_block() { int device = 0; if (cudaGetDevice(&device) != cudaSuccess) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 1a9215a4..91622f71 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -19,18 +19,15 @@ static void ovo_streaming_csr_impl( std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); - auto t1 = make_tier1_config(h_offsets.data(), n_groups); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.any_above_t2 && t1.use_tier1; - bool needs_tier3 = t1.any_above_t2 && !use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; std::vector h_sort_group_ids; int n_sort_groups = n_groups; - if (needs_tier3) { - h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, - TIER2_GROUP_THRESHOLD); + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); n_sort_groups = (int)h_sort_group_ids.size(); } @@ -63,7 +60,7 @@ static void ovo_streaming_csr_impl( RmmScratchPool pool; size_t cub_temp_bytes = 0; - if (needs_tier3) { + if (run_huge) { size_t cub_grp_bytes = 0; int sub_grp_items_i32 = checked_cub_items(sub_grp_items, "OVO device CSR group sub-batch"); @@ -81,7 +78,7 @@ static void ovo_streaming_csr_impl( cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); int* d_sort_group_ids = nullptr; - if (needs_tier3) { + if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), @@ -102,17 +99,16 @@ static void ovo_streaming_csr_impl( for (int s = 0; s < n_streams; s++) { bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].cub_temp = - needs_tier3 ? pool.alloc(cub_temp_bytes) : nullptr; + run_huge ? pool.alloc(cub_temp_bytes) : nullptr; bufs[s].ref_tie_sums = - (compute_tie_corr && - (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = pool.alloc((size_t)n_groups * sub_batch_cols); - if (needs_tier3) { + if (run_huge) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_seg = checked_int_product( (size_t)n_sort_groups, (size_t)sub_batch_cols, @@ -189,78 +185,14 @@ static void ovo_streaming_csr_impl( CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); } - int skip_le = 0; - bool run_tier0 = t1.use_tier0; - bool run_tier0_64 = t1.any_tier0_64; - bool run_tier2 = t1.any_tier2; - if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { - launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, - stream); - } - if (run_tier0) { - launch_tier0(ref_sub, buf.grp_dense, grp_offsets, - buf.ref_tie_sums, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, stream); - if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; - } - if (run_tier0_64) { - launch_tier0_64(ref_sub, buf.grp_dense, grp_offsets, - buf.ref_tie_sums, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, skip_le, stream); - if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { - skip_le = TIER0_64_GROUP_THRESHOLD; - } - } - if (run_tier2) { - launch_tier2_medium( - ref_sub, buf.grp_dense, grp_offsets, buf.ref_tie_sums, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, skip_le, stream); - } - - int upper_skip_le = - t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; - if (t1.any_above_t2 && use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - ref_sub, buf.grp_dense, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else if (needs_tier3) { - int sb_grp_seg = checked_int_product( - (size_t)n_sort_groups, (size_t)sb_cols, - "OVO device CSR active group segment count"); - { - int blk = - (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_tier3_seg_begin_end_offsets_kernel<<< - blk, UTIL_BLOCK_SIZE, 0, stream>>>( - grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, - buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); - CUDA_CHECK_LAST_ERROR( - build_tier3_seg_begin_end_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(ref_sub, buf.grp_dense, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), @@ -306,18 +238,15 @@ static void ovo_streaming_csc_impl( std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); - auto t1 = make_tier1_config(h_offsets.data(), n_groups); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.any_above_t2 && t1.use_tier1; - bool needs_tier3 = t1.any_above_t2 && !use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; std::vector h_sort_group_ids; int n_sort_groups = n_groups; - if (needs_tier3) { - h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, - TIER2_GROUP_THRESHOLD); + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); n_sort_groups = (int)h_sort_group_ids.size(); } @@ -335,7 +264,7 @@ static void ovo_streaming_csc_impl( size_t cub_ref_bytes = cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); size_t cub_temp_bytes = cub_ref_bytes; - if (needs_tier3) { + if (run_huge) { size_t cub_grp_bytes = 0; int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, @@ -350,7 +279,7 @@ static void ovo_streaming_csc_impl( RmmScratchPool pool; int* d_sort_group_ids = nullptr; - if (needs_tier3) { + if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), @@ -378,15 +307,14 @@ static void ovo_streaming_csc_impl( bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].ref_tie_sums = - (compute_tie_corr && - (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = pool.alloc((size_t)n_groups * sub_batch_cols); - if (needs_tier3) { + if (run_huge) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_grp_seg = checked_int_product( (size_t)n_sort_groups, (size_t)sub_batch_cols, @@ -439,74 +367,14 @@ static void ovo_streaming_csc_impl( n_all_grp, col); CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - int skip_le = 0; - bool run_tier0 = t1.use_tier0; - bool run_tier0_64 = t1.any_tier0_64; - bool run_tier2 = t1.any_tier2; - if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { - launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, - sb_cols, stream); - } - if (run_tier0) { - launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, - buf.ref_tie_sums, buf.sub_rank_sums, buf.sub_tie_corr, - n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, - stream); - if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; - } - if (run_tier0_64) { - launch_tier0_64(buf.ref_sorted, buf.grp_dense, grp_offsets, - buf.ref_tie_sums, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, skip_le, stream); - if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { - skip_le = TIER0_64_GROUP_THRESHOLD; - } - } - if (run_tier2) { - launch_tier2_medium(buf.ref_sorted, buf.grp_dense, grp_offsets, - buf.ref_tie_sums, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, skip_le, stream); - } - - int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; - if (t1.any_above_t2 && use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else if (needs_tier3) { - int sb_grp_seg = checked_int_product( - (size_t)n_sort_groups, (size_t)sb_cols, - "OVO device CSC active group segment count"); - { - int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_tier3_seg_begin_end_offsets_kernel<<>>( - grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, - buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); - CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, grp_offsets, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(buf.ref_sorted, buf.grp_dense, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index fc91821b..deb2f395 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -19,18 +19,15 @@ static void ovo_streaming_csc_host_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- - auto t1 = make_tier1_config(h_grp_offsets, n_groups); + auto t1 = make_ovo_tier_plan(h_grp_offsets, n_groups); int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.any_above_t2 && t1.use_tier1; - bool needs_tier3 = t1.any_above_t2 && !use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; std::vector h_sort_group_ids; int n_sort_groups = n_groups; - if (needs_tier3) { + if (run_huge) { h_sort_group_ids = - make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + make_sort_group_ids(h_grp_offsets, n_groups, OVO_MEDIUM_MAX); n_sort_groups = (int)h_sort_group_ids.size(); } @@ -49,7 +46,7 @@ static void ovo_streaming_csc_host_impl( size_t cub_ref_bytes = cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); size_t cub_temp_bytes = cub_ref_bytes; - if (needs_tier3) { + if (run_huge) { int max_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, "OVO host CSC group segment count"); @@ -103,7 +100,7 @@ static void ovo_streaming_csc_host_impl( cudaMemcpyHostToDevice); cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), cudaMemcpyHostToDevice); - if (needs_tier3) { + if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), @@ -142,8 +139,7 @@ static void ovo_streaming_csc_host_impl( bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].ref_tie_sums = - (compute_tie_corr && - (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].d_rank_sums = @@ -156,7 +152,7 @@ static void ovo_streaming_csc_host_impl( compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); bufs[s].d_group_nnz = pool.alloc( compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); - if (needs_tier3) { + if (run_huge) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_grp_seg = checked_int_product( (size_t)n_sort_groups, (size_t)sub_batch_cols, @@ -242,74 +238,14 @@ static void ovo_streaming_csc_host_impl( CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); // ---- Tier dispatch: sort grp + rank ---- - int skip_le = 0; - bool run_tier0 = t1.use_tier0; - bool run_tier0_64 = t1.any_tier0_64; - bool run_tier2 = t1.any_tier2; - if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { - launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, - sb_cols, stream); - } - if (run_tier0) { - launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, - buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, - n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, - stream); - if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; - } - if (run_tier0_64) { - launch_tier0_64(buf.ref_sorted, buf.grp_dense, d_grp_offsets, - buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, - n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, skip_le, stream); - if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { - skip_le = TIER0_64_GROUP_THRESHOLD; - } - } - if (run_tier2) { - launch_tier2_medium(buf.ref_sorted, buf.grp_dense, d_grp_offsets, - buf.ref_tie_sums, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, skip_le, stream); - } - - int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; - if (t1.any_above_t2 && use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else if (needs_tier3) { - int sb_grp_seg = - checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, - "OVO host CSC active group segment count"); - { - int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_tier3_seg_begin_end_offsets_kernel<<>>( - d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, - buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); - CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, d_grp_offsets, - buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } + OvoTierScratch sc{buf.ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(buf.ref_sorted, buf.grp_dense, d_grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_actual, tpb_rank, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, stream); // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), @@ -632,8 +568,8 @@ static void ovo_streaming_csr_host_impl( cudaStreamDestroy(ref_stream); // ---- Phase 2: Per-pack streaming ---- - auto t1 = make_tier1_config(h_grp_offsets, n_test); - bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + auto t1 = make_ovo_tier_plan(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > OVO_LARGE_MAX); constexpr int MAX_GROUP_STREAMS = 4; int n_streams = MAX_GROUP_STREAMS; @@ -707,17 +643,17 @@ static void ovo_streaming_csr_host_impl( const Pack& pack = packs[p]; int K = pack.end - pack.first; if (K == 0 || pack.n_rows == 0) continue; - Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + OvoTierPlan pack_t1 = make_ovo_tier_plan(h_grp_offsets + pack.first, K); int pack_tpb_rank = round_up_to_warp( std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); - bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; - int pack_tier3_skip_le = - pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + bool pack_has_above_t2 = pack_t1.max_grp_size > OVO_MEDIUM_MAX; + int pack_huge_skip_le = + pack_has_above_t2 ? OVO_MEDIUM_MAX : OVO_WARP_MAX; std::vector h_sort_group_ids; int pack_n_sort_groups = K; - if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + if (pack_t1.above_warp && !pack_t1.run_large) { h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, - K, pack_tier3_skip_le); + K, pack_huge_skip_le); pack_n_sort_groups = (int)h_sort_group_ids.size(); } @@ -725,7 +661,7 @@ static void ovo_streaming_csr_host_impl( cudaStream_t stream = streams[s]; auto& buf = bufs[s]; - if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + if (pack_t1.above_warp && !pack_t1.run_large) { cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), cudaMemcpyHostToDevice, stream); @@ -798,77 +734,15 @@ static void ovo_streaming_csr_host_impl( const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; - int skip_le = 0; - bool run_tier0 = pack_t1.use_tier0; - bool run_tier0_64 = pack_t1.any_tier0_64; - bool run_tier2 = pack_t1.any_tier2; - if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { - launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, sb_cols, - stream); - } - if (run_tier0) { - launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, - buf.d_ref_tie_sums, buf.d_rank_sums, - buf.d_tie_corr, n_ref, pack_rows, sb_cols, K, - compute_tie_corr, stream); - if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; - } - if (run_tier0_64) { - launch_tier0_64( - ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, - buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, - pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); - if (pack_t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { - skip_le = TIER0_64_GROUP_THRESHOLD; - } - } - if (run_tier2) { - launch_tier2_medium( - ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, - buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, - pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); - } - - int upper_skip_le = - pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; - if (pack_has_above_t2 && pack_t1.use_tier1) { - dim3 grid(sb_cols, K); - ovo_fused_sort_rank_kernel<<>>( - ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, - buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, - K, compute_tie_corr, pack_t1.padded_grp_size, - upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else if (pack_has_above_t2) { - int n_seg = checked_int_product( - (size_t)pack_n_sort_groups, (size_t)sb_cols, - "OVO host CSR active group segment count"); - { - int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_tier3_seg_begin_end_offsets_kernel<<< - blk, UTIL_BLOCK_SIZE, 0, stream>>>( - buf.d_pack_grp_offsets, buf.d_sort_group_ids, - buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, - pack_n_sort_groups, sb_cols); - CUDA_CHECK_LAST_ERROR( - build_tier3_seg_begin_end_offsets_kernel); - } - { - size_t temp = cub_grp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, - sb_items, n_seg, buf.d_grp_seg_offsets, - buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); - } - dim3 grid(sb_cols, K); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, - buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, - K, compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } + OvoTierScratch sc{buf.d_ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, buf.d_grp_sorted, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + pack_t1, sc, buf.d_sort_group_ids, + pack_n_sort_groups, cub_grp_bytes, sb_items, + pack_tpb_rank, n_ref, pack_rows, sb_cols, K, + compute_tie_corr, stream); cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, n_cols * sizeof(double), buf.d_rank_sums, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index c93ceccf..b9bad9cf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -1,12 +1,14 @@ #pragma once +#include + /** - * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Build CUB segmented-sort ranges only for groups in the HUGE band. * Group ids are relative to grp_offsets, and ranges still point into the * original dense group layout so the presorted rank kernel can read from the * normal per-group positions. */ -__global__ void build_tier3_seg_begin_end_offsets_kernel( +__global__ void build_huge_seg_offsets_kernel( const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, int n_sort_groups, int sb_cols) { @@ -56,74 +58,69 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, * kernel. This struct bundles the sizing knobs derived from the host-side * group offsets so each streaming impl can drop a 15-line prep block. */ -struct Tier1Config { +struct OvoTierPlan { int max_grp_size = 0; int min_grp_size = 0; - bool use_tier0 = - false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) - bool use_tier1 = + bool run_warp = false; // any group fits in one warp (≤ OVO_WARP_MAX) + bool run_large = false; // any group needs > tier0 but fits in tier1 smem sort - bool any_above_t0 = - false; // at least one group exceeds TIER0_GROUP_THRESHOLD - bool any_tier0_64 = false; // any group needs Tier 0.5: (T0, T0_64] - bool any_tier2 = false; // any group needs Tier 2: (T0_64, T2] - bool any_above_t2 = - false; // at least one group exceeds TIER2_GROUP_THRESHOLD - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; + bool above_warp = false; // at least one group exceeds OVO_WARP_MAX + bool run_small = false; // any group needs Tier 0.5: (T0, T0_64] + bool run_medium = false; // any group needs Tier 2: (T0_64, T2] + bool above_medium = false; // at least one group exceeds OVO_MEDIUM_MAX + int large_padded = 0; + int large_tpb = 0; + size_t large_smem = 0; }; -// SINGLE source of truth for OVO tier dispatch, used by EVERY OVO path: -// dense (wilcoxon.cu) AND all four sparse OVO impls (host/device CSC/CSR), -// which extract ref+group rows to dense per sub-batch and then call this. -// Scans per-group sizes once and returns which size-gated tiers to co-launch: -// Tier 0 (<=32, TIER0_GROUP_THRESHOLD): ovo_warp_sort_rank_kernel -// Tier 0.5 (<=64, TIER0_64_GROUP_THRESHOLD): ovo_small64_sort_rank_kernel -// Tier 2 (<=512, TIER2_GROUP_THRESHOLD): ovo_medium_unsorted_rank_kernel -// Tier 1 (<=2500, TIER1_GROUP_THRESHOLD): ovo_fused_sort_rank_kernel (smem -// sort) Tier 3 (>2500): CUB segmented sort + -// batched_rank_sums_presorted_kernel -// Tiers cooperate via skip_n_grp_le: a larger tier skips groups a smaller tier -// already handled. Tier 1 is device-adapted: if its fused-sort smem footprint -// would exceed wilcoxon_max_smem_per_block() it is disabled in favor of Tier 3. -static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { - Tier1Config c; +// Single source of truth for OVO tier dispatch (used by the dense path AND all +// four sparse OVO impls, which extract ref+group rows to dense then call this). +// Scans group sizes once; returns which size bands to co-launch (by max group): +// WARP (<=32): ovo_rank_warp_kernel (warp-shuffle sort, in registers) +// SMALL (<=64): ovo_rank_small_kernel (fixed 64-element smem sort) +// MEDIUM (<=512): ovo_rank_medium_kernel (no sort; O(n^2) in-group count) +// LARGE (<=2500): ovo_rank_large_kernel (fused smem bitonic sort) +// HUGE (>2500): CUB segmented sort + ovo_rank_huge_kernel (presorted rank) +// Bands cooperate via skip_n_grp_le (a larger band skips groups a smaller one +// already handled). LARGE is device-adapted: if its smem would exceed the +// per-block limit it falls back to HUGE. +static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { + OvoTierPlan c; c.min_grp_size = INT_MAX; for (int g = 0; g < n_groups; g++) { int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; if (sz > c.max_grp_size) c.max_grp_size = sz; if (sz < c.min_grp_size) c.min_grp_size = sz; - if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER0_64_GROUP_THRESHOLD) { - c.any_tier0_64 = true; + if (sz > OVO_WARP_MAX && sz <= OVO_SMALL_MAX) { + c.run_small = true; } - if (sz > TIER0_64_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { - c.any_tier2 = true; + if (sz > OVO_SMALL_MAX && sz <= OVO_MEDIUM_MAX) { + c.run_medium = true; } - if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + if (sz > OVO_MEDIUM_MAX) c.above_medium = true; } if (n_groups == 0) c.min_grp_size = 0; - // use_tier0: Tier 0 kernel is worth running (at least one group small + // run_warp: Tier 0 kernel is worth running (at least one group small // enough to benefit from the warp path). - c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); - // any_above_t0: at least one group needs a non-Tier-0 kernel. - c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); - // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). - c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); - if (c.use_tier1) { - c.padded_grp_size = 1; - while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; - c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); - c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + c.run_warp = (c.min_grp_size <= OVO_WARP_MAX); + // above_warp: at least one group needs a non-Tier-0 kernel. + c.above_warp = (c.max_grp_size > OVO_WARP_MAX); + // run_large: the fused smem-sort fast path (for groups > T0 but ≤ T1). + c.run_large = c.above_warp && (c.max_grp_size <= OVO_LARGE_MAX); + if (c.run_large) { + c.large_padded = 1; + while (c.large_padded < c.max_grp_size) c.large_padded <<= 1; + c.large_tpb = std::min(c.large_padded, MAX_THREADS_PER_BLOCK); + c.large_smem = (size_t)c.large_padded * sizeof(float) + WARP_REDUCE_BUF * sizeof(double); // Adapt to the device: if the fused-sort buffer would exceed the // per-block shared-memory limit, fall back to the tier-3 CUB segmented // sort (which has no smem cap) rather than launching a kernel that // would fail. Never triggers at the current threshold (~16.6KB), but // keeps the dispatch correct if the threshold or device limit changes. - if (c.tier1_smem > wilcoxon_max_smem_per_block()) { - c.use_tier1 = false; + if (c.large_smem > wilcoxon_max_smem_per_block()) { + c.run_large = false; } } return c; @@ -143,18 +140,19 @@ static std::vector make_sort_group_ids(const int* h_grp_offsets, // Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) // pair per warp. grid.y covers ceil(K/8) pair rows. -static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, - const int* grp_offsets, - const double* ref_tie_sums, double* rank_sums, - double* tie_corr, int n_ref, int n_all_grp, - int sb_cols, int K, bool compute_tie_corr, - cudaStream_t stream) { +static inline void launch_ovo_warp(const float* ref_sorted, + const float* grp_dense, + const int* grp_offsets, + const double* ref_tie_sums, + double* rank_sums, double* tie_corr, + int n_ref, int n_all_grp, int sb_cols, int K, + bool compute_tie_corr, cudaStream_t stream) { constexpr int WARPS_PER_BLOCK = 8; dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); - ovo_warp_sort_rank_kernel<<>>( + ovo_rank_warp_kernel<<>>( ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, n_ref, n_all_grp, sb_cols, K, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); + CUDA_CHECK_LAST_ERROR(ovo_rank_warp_kernel); } static inline void launch_ref_tie_sums(const float* ref_sorted, @@ -165,30 +163,118 @@ static inline void launch_ref_tie_sums(const float* ref_sorted, CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); } -static inline void launch_tier0_64( +static inline void launch_ovo_small( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, cudaStream_t stream) { dim3 grid(sb_cols, K); - ovo_small64_sort_rank_kernel<<>>( + ovo_rank_small_kernel<<>>( ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le); - CUDA_CHECK_LAST_ERROR(ovo_small64_sort_rank_kernel); + CUDA_CHECK_LAST_ERROR(ovo_rank_small_kernel); } -static inline void launch_tier2_medium( +static inline void launch_ovo_medium( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, cudaStream_t stream) { constexpr int tpb = 256; - size_t smem = (size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + size_t smem = (size_t)OVO_MEDIUM_MAX * sizeof(float) + WARP_REDUCE_BUF * sizeof(double); dim3 grid(sb_cols, K); - ovo_medium_unsorted_rank_kernel<<>>( + ovo_rank_medium_kernel<<>>( ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, - TIER2_GROUP_THRESHOLD); - CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); + OVO_MEDIUM_MAX); + CUDA_CHECK_LAST_ERROR(ovo_rank_medium_kernel); +} + +// Per-stream scratch consumed by ovo_dispatch_tiers (one set per CUDA stream). +// grp_sorted/grp_seg_*/grp_cub_temp are only needed for the HUGE band and may +// be null otherwise. +struct OvoTierScratch { + double* ref_tie_sums; // [sb_cols] pre-computed reference tie sums, or null + double* sub_rank_sums; // [n_groups * sb_cols] rank-sum output accumulator + double* sub_tie_corr; // [n_groups * sb_cols] tie-correction output + float* grp_sorted; // HUGE: [n_all_grp * sb_cols] sorted group values + int* grp_seg_offsets; // HUGE: CUB segment begins + int* grp_seg_ends; // HUGE: CUB segment ends + uint8_t* grp_cub_temp; // HUGE: CUB scratch +}; + +// SINGLE OVO ranking engine, shared by the dense path and all four sparse OVO +// impls (host/device CSC/CSR). Given an already-sorted reference slice and a +// dense group slice for one column sub-batch, it runs the size-banded dispatch +// from `plan` (see make_ovo_tier_plan): co-launch WARP/SMALL/MEDIUM for small +// groups, then LARGE (fused smem sort) OR HUGE (CUB segmented sort) for the +// rest. Pure host-side code motion: the kernel launches are identical to the +// previous inline copies, so results and performance are unchanged. The five +// callers differ only in how they produce ref_sorted / grp_dense. +static inline void ovo_dispatch_tiers( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const OvoTierPlan& plan, const OvoTierScratch& sc, + const int* d_sort_group_ids, int n_sort_groups, size_t grp_cub_temp_bytes, + int sb_grp_items_actual, int tpb_rank, int n_ref, int n_all_grp, + int sb_cols, int n_groups, bool compute_tie_corr, cudaStream_t stream) { + bool run_large = plan.above_medium && plan.run_large; + bool run_huge = plan.above_medium && !run_large; + + int skip_le = 0; + if (compute_tie_corr && + (plan.run_warp || plan.run_small || plan.run_medium)) { + launch_ref_tie_sums(ref_sorted, sc.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (plan.run_warp) { + launch_ovo_warp(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (plan.above_warp) skip_le = OVO_WARP_MAX; + } + if (plan.run_small) { + launch_ovo_small(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, skip_le, stream); + if (plan.max_grp_size > OVO_SMALL_MAX) skip_le = OVO_SMALL_MAX; + } + if (plan.run_medium) { + launch_ovo_medium(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = plan.above_medium ? OVO_MEDIUM_MAX : skip_le; + if (plan.above_medium && run_large) { + dim3 grid(sb_cols, n_groups); + ovo_rank_large_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, sc.sub_rank_sums, + sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, plan.large_padded, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_large_kernel); + } else if (run_huge) { + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "OVO active group segment count"); + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_huge_seg_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, sc.grp_seg_offsets, sc.grp_seg_ends, + n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_huge_seg_offsets_kernel); + + size_t temp = grp_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + sc.grp_cub_temp, temp, grp_dense, sc.grp_sorted, + sb_grp_items_actual, sb_grp_seg, sc.grp_seg_offsets, + sc.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + + dim3 grid(sb_cols, n_groups); + ovo_rank_huge_kernel<<>>( + ref_sorted, sc.grp_sorted, grp_offsets, sc.sub_rank_sums, + sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_huge_kernel); + } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 6b3b8dbf..e683fadf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -182,21 +182,11 @@ static void ovr_sparse_csc_host_streaming_impl( } // Sparse rank kernel (stats already captured above) - if (rank_use_gmem) { - cudaMemsetAsync(buf.d_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel - <<>>( - buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, - d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, - buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, - rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + launch_ovr_sparse_rank( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, + rank_use_gmem, stream); // D2D: scatter sub-batch results into caller's GPU buffers cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), @@ -465,20 +455,11 @@ static void ovr_sparse_csr_host_streaming_impl( stream); } - if (rank_use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( + launch_ovr_sparse_rank( buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, - buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, - rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + buf.d_nz_scratch, n_rows, sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), @@ -626,19 +607,11 @@ static void ovr_sparse_csc_streaming_impl( } // Sparse rank kernel (handles implicit zeros analytically) - if (rank_use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, - group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + launch_ovr_sparse_rank(buf.keys_out, buf.vals_out, buf.seg_offsets, + group_codes, group_sizes, buf.sub_rank_sums, + buf.sub_tie_corr, buf.d_nz_scratch, n_rows, + sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), @@ -837,19 +810,11 @@ static void ovr_sparse_csr_streaming_impl( } // Sparse rank kernel (handles implicit zeros analytically) - if (rank_use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, - group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + launch_ovr_sparse_rank(buf.keys_out, buf.vals_out, buf.col_offsets, + group_codes, group_sizes, buf.sub_rank_sums, + buf.sub_tie_corr, buf.d_nz_scratch, n_rows, + sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 54fc42d4..04c1ce63 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -236,6 +236,33 @@ __global__ void rank_sums_sparse_ovr_kernel( } } +// Shared sparse-OVR rank launch, used by all four sparse OVR impls (they differ +// only in how they produce the sorted nonzeros and how they scatter results). +// Optionally zeroes the global-memory accumulators, then launches the +// analytic-zero rank kernel. use_gmem is the CRITICAL large-n_groups / +// perturbation fallback (see sparse_ovr_smem_config) — DO NOT drop the gmem +// branch. ValT is the sorted-row-index type (int everywhere today). +template +static inline void launch_ovr_sparse_rank( + const float* sorted_vals, const ValT* sorted_row_idx, + const int* col_seg_offsets, const int* group_codes, + const double* group_sizes, double* rank_sums, double* tie_corr, + double* nz_count_scratch, int n_rows, int sb_cols, int n_groups, int tpb, + size_t smem_bytes, bool compute_tie_corr, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + cudaMemsetAsync(rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream); + cudaMemsetAsync(nz_count_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream); + } + rank_sums_sparse_ovr_kernel<<>>( + sorted_vals, sorted_row_idx, col_seg_offsets, group_codes, group_sizes, + rank_sums, tie_corr, nz_count_scratch, n_rows, sb_cols, n_groups, + compute_tie_corr, use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); +} + // CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). // // Decide smem-vs-gmem for the sparse-OVR stats cast-and-accumulate kernel From f8c00d8ef921fa0977338dd3e9734fc7da530965 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 3 Jun 2026 20:16:32 +0200 Subject: [PATCH 11/36] first draft --- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 20 ++--- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 14 +-- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 27 +++--- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 8 +- tests/test_rank_genes_groups_wilcoxon.py | 87 +++++++++++++++++++ 5 files changed, 125 insertions(+), 31 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index bfbc0dc2..2ae77947 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -259,8 +259,8 @@ __global__ void ovo_rank_huge_kernel( } // ============================================================================ -// Tier 1 fused kernel: smem bitonic sort + binary search rank sums -// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// LARGE-band fused kernel: smem bitonic sort + binary search rank sums +// For groups up to OVO_LARGE_MAX cells. No CUB, no global memory sort buffers. // Grid: (n_cols, n_groups), Block: min(large_padded, 512) // Shared memory: large_padded floats + 32 doubles (warp reduction) // ============================================================================ @@ -282,7 +282,7 @@ __global__ void ovo_rank_large_kernel( int g_end = grp_offsets[grp + 1]; int n_grp = g_end - g_start; - // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // Size-gated dispatch: when co-launched with the WARP kernel we // skip groups it's already handling. Each group owns its own // rank_sums row, so the two kernels' writes never alias. if (n_grp <= skip_n_grp_le) return; @@ -402,7 +402,7 @@ __global__ void ovo_rank_large_kernel( } // ============================================================================ -// Tier 2 helper: tie contribution of the sorted reference alone. +// MEDIUM-band helper: tie contribution of the sorted reference alone. // One block per column. The medium unsorted-rank kernel uses this as a base // and only adds group-only/overlap deltas from the unsorted group values. // ============================================================================ @@ -561,7 +561,7 @@ __global__ void ovo_rank_small_kernel( } // ============================================================================ -// Tier 2 fused kernel: no-sort direct rank for medium groups. +// MEDIUM-band fused kernel: no-sort direct rank for medium groups. // // Avoids the smem bitonic sort for groups in (skip_n_grp_le, // max_n_grp_le]. Ranks are computed from ref binary searches plus an @@ -667,7 +667,7 @@ __global__ void ovo_rank_medium_kernel( } // ============================================================================ -// Warp-scoped tie correction for Tier 0. +// Warp-scoped tie correction for the WARP band. // // Sorted values live in a 32-lane register (one per lane, with unused lanes // carrying +INF). Walks unique values via lane-step differentials and @@ -833,7 +833,7 @@ __device__ __forceinline__ double warp_tie_delta(const float* ref_col, } // ============================================================================ -// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// WARP-band kernel: warp-per-(col, group) pair, 8 warps packed per block. // // Each warp independently: // 1. Loads ≤ 32 group values into a single register (one per lane, @@ -844,7 +844,7 @@ __device__ __forceinline__ double warp_tie_delta(const float* ref_col, // 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. // // 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair -// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// LARGE band, and the lack of __syncthreads / smem sort lets each warp run // independently at full throughput. // // Grid: (n_cols, ceil(n_groups / 8)), Block: 256. @@ -871,7 +871,7 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, int n_grp = g_end - g_start; // This kernel only handles groups that fit in a single warp (one value - // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // per lane). Larger groups are delegated to LARGE/HUGE in a co-launched // kernel; since each group owns its own row in rank_sums/tie_corr, the // two kernels interlace into the output without conflict. if (n_grp > OVO_WARP_MAX) return; @@ -949,7 +949,7 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane - // gets the same mid-rank. This matches the Tier 1 accumulation. + // gets the same mid-rank. This matches the LARGE-band accumulation. local_sum = (double)(n_lt_ref + n_lt_grp) + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index d1f32335..11fa4dbc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -26,24 +26,24 @@ constexpr int END_BIT = 32; constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; -// Max group size for the super-fast "warp-per-(col,group)" fused kernel -// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in +// Max group size for the super-fast "warp-per-(col,group)" fused kernel (the +// WARP band). Each warp sorts and ranks one (col, group) pair entirely in // registers via warp-shuffle bitonic sort — no smem sort buffer, no // __syncthreads(). Blocks pack 8 warps so block launch overhead is // amortised 8× across (col, group) work items. This path is the fast // route for per-celltype perturbation-style workloads where most test // groups have only a few dozen cells. constexpr int OVO_WARP_MAX = 32; -// Second small-group tier for perturbation workloads where most groups are -// slightly larger than one warp. Uses one compact shared-memory sort block per -// (column, group), avoiding the heavier Tier 2 in-group scan. +// SMALL band for perturbation workloads where most groups are slightly larger +// than one warp. Uses one compact shared-memory sort block per (column, +// group), avoiding the heavier MEDIUM-band in-group scan. constexpr int OVO_SMALL_MAX = 64; // Medium-group cutoff for the unsorted direct-rank kernel. For perturbation // workloads most groups sit below this range, where avoiding a full smem // bitonic sort wins despite the O(n^2) in-group count. constexpr int OVO_MEDIUM_MAX = 512; -// Max group size for the fused smem-sort rank kernel (Tier 1 fast path). -// Beyond this, fall back to CUB segmented sort + binary-search rank kernel. +// Max group size for the fused smem-sort rank kernel (the LARGE band). +// Beyond this, fall back to the HUGE band: CUB segmented sort + rank kernel. constexpr int OVO_LARGE_MAX = 2500; // Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes // each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index b9bad9cf..9f823aba 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -52,21 +52,22 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } /** - * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * LARGE-band dispatch: when the largest group fits in shared memory, a fused * bitonic-sort + binary-search kernel handles the whole group per block. - * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank - * kernel. This struct bundles the sizing knobs derived from the host-side - * group offsets so each streaming impl can drop a 15-line prep block. + * Otherwise we fall back to the HUGE band (CUB segmented sort plus the + * pre-sorted rank kernel). This struct bundles the sizing knobs derived from + * the host-side group offsets so each streaming impl can drop a 15-line prep + * block. */ struct OvoTierPlan { int max_grp_size = 0; int min_grp_size = 0; bool run_warp = false; // any group fits in one warp (≤ OVO_WARP_MAX) bool run_large = - false; // any group needs > tier0 but fits in tier1 smem sort + false; // any group needs > WARP but fits the LARGE smem-sort band bool above_warp = false; // at least one group exceeds OVO_WARP_MAX - bool run_small = false; // any group needs Tier 0.5: (T0, T0_64] - bool run_medium = false; // any group needs Tier 2: (T0_64, T2] + bool run_small = false; // SMALL band: (OVO_WARP_MAX, OVO_SMALL_MAX] + bool run_medium = false; // MEDIUM band: (OVO_SMALL_MAX, OVO_MEDIUM_MAX] bool above_medium = false; // at least one group exceeds OVO_MEDIUM_MAX int large_padded = 0; int large_tpb = 0; @@ -101,12 +102,12 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { } if (n_groups == 0) c.min_grp_size = 0; - // run_warp: Tier 0 kernel is worth running (at least one group small + // run_warp: WARP kernel is worth running (at least one group small // enough to benefit from the warp path). c.run_warp = (c.min_grp_size <= OVO_WARP_MAX); - // above_warp: at least one group needs a non-Tier-0 kernel. + // above_warp: at least one group needs a non-WARP kernel. c.above_warp = (c.max_grp_size > OVO_WARP_MAX); - // run_large: the fused smem-sort fast path (for groups > T0 but ≤ T1). + // run_large: the fused smem-sort fast path (groups > WARP but ≤ LARGE). c.run_large = c.above_warp && (c.max_grp_size <= OVO_LARGE_MAX); if (c.run_large) { c.large_padded = 1; @@ -115,8 +116,8 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { c.large_smem = (size_t)c.large_padded * sizeof(float) + WARP_REDUCE_BUF * sizeof(double); // Adapt to the device: if the fused-sort buffer would exceed the - // per-block shared-memory limit, fall back to the tier-3 CUB segmented - // sort (which has no smem cap) rather than launching a kernel that + // per-block shared-memory limit, fall back to the HUGE-band CUB + // segmented sort (no smem cap) rather than launching a kernel that // would fail. Never triggers at the current threshold (~16.6KB), but // keeps the dispatch correct if the threshold or device limit changes. if (c.large_smem > wilcoxon_max_smem_per_block()) { @@ -138,7 +139,7 @@ static std::vector make_sort_group_ids(const int* h_grp_offsets, return ids; } -// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// WARP kernel launcher: 8 warps × 32 threads per block, one (col, group) // pair per warp. grid.y covers ceil(K/8) pair rows. static inline void launch_ovo_warp(const float* ref_sorted, const float* grp_dense, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 04c1ce63..6f8d90df 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -310,10 +310,16 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( int seg_start = col_seg_offsets[col]; int seg_end = col_seg_offsets[col + 1]; + // Packed layout matching cast_accumulate_smem_config, which sizes the + // dynamic smem as (1 + compute_sq_sums + compute_nnz) * n_groups doubles. + // s_nnz must follow only the arrays that are actually present: using a + // fixed 2*n_groups offset over-runs the allocation when sq-sums is off but + // nnz is on (the host OVR pts path), corrupting/faulting at larger + // n_groups. extern __shared__ double smem[]; double* s_sum = smem; double* s_sq = smem + n_groups; - double* s_nnz = smem + 2 * n_groups; + double* s_nnz = smem + (compute_sq_sums ? 2 : 1) * n_groups; for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { s_sum[g] = 0.0; diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index b1cb3298..c052aea8 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -877,6 +877,93 @@ def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): ) +# Regression guard for a shared-memory OOB write in the host sparse OVR +# cast-and-accumulate kernel: it placed the per-group nnz accumulator at a fixed +# 2*n_groups smem offset, but cast_accumulate_smem_config packs only the enabled +# arrays -- and the host OVR path runs with sq-sums OFF, nnz ON (pts=True). The +# overrun was benign at tiny n_groups (it landed in rounded smem slack, and the +# write/read used the same wrong offset so values stayed self-consistent) but +# caused an illegal memory access once n_groups grew past ~25. n_groups=50 + +# pts=True is the faulting regime, with the smem (non-gmem) accumulator still +# selected. Covers both host sparse formats (the ones that crashed) plus the +# dense/device formats for full parity. +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +def test_wilcoxon_ovr_pts_many_groups_match_scanpy(fmt): + adata_gpu = _make_sized_groups_adata([40] * 50, n_genes=8, seed=4) + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + "pts": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu = adata_gpu.uns["rank_genes_groups"] + cpu = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in gpu[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu[field][group], dtype=float), + np.asarray(cpu[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + gpu_pts, cpu_pts = gpu["pts"], cpu["pts"] + assert list(gpu_pts.columns) == list(cpu_pts.columns) + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +# Companion to test_wilcoxon_ovr_many_groups_gmem_formats_agree, with pts=True: +# at gmem scale (n_groups > ~3056) the global cast-accumulate and the +# analytic-zero rank kernel both drive the per-group nnz path. scanpy's +# 3000+-group build is too slow for an in-suite parity check, so we assert every +# storage format agrees, including the pts fraction-expressing matrix. +def test_wilcoxon_ovr_many_groups_gmem_pts_formats_agree(): + adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=5) + ref = None + for fmt in ("numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + pts=True, + n_genes=6, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + cur["pts"] = r["pts"].values + if ref is None: + ref = cur + continue + for field in ("scores", "pvals", "pts"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + @pytest.mark.parametrize( ("groups", "reference"), [ From 28bc282d2c063519069878c007e1892f341e0d74 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 5 Jun 2026 21:21:35 +0200 Subject: [PATCH 12/36] update rmm --- pyproject.toml | 26 +++++++++++-------- .../_cuda/wilcoxon/wilcoxon_rmm.cu | 11 +++++--- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e0961f15..3736e38d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,11 @@ requires = [ # Generic isolated source builds default to CUDA 12. CUDA wheel builds # rewrite this to the matching cu12/cu13 package; CUDA 13 source builds # should build in an existing RAPIDS env with --no-build-isolation. - "librmm-cu12>=25.10", + # 25.12 floor: the Wilcoxon scratch allocator uses the resource-ref API + # (get_current_device_resource_ref().allocate_sync), and the flat + # header path only exists from 25.12 + # onward (25.10 kept it under rmm/mr/device/). Builds through RMM 26.06+. + "librmm-cu12>=25.12", ] build-backend = "scikit_build_core.build" @@ -39,19 +43,19 @@ dependencies = [ [project.optional-dependencies] rapids-cu13 = [ "cupy-cuda13x", - "cudf-cu13>=25.10", - "cuml-cu13>=25.10", - "cugraph-cu13>=25.10", - "cuvs-cu13>=25.10", - "librmm-cu13>=25.10", + "cudf-cu13>=25.12", + "cuml-cu13>=25.12", + "cugraph-cu13>=25.12", + "cuvs-cu13>=25.12", + "librmm-cu13>=25.12", ] rapids-cu12 = [ "cupy-cuda12x", - "cudf-cu12>=25.10", - "cuml-cu12>=25.10", - "cugraph-cu12>=25.10", - "cuvs-cu12>=25.10", - "librmm-cu12>=25.10", + "cudf-cu12>=25.12", + "cuml-cu12>=25.12", + "cugraph-cu12>=25.12", + "cuvs-cu12>=25.12", + "librmm-cu12>=25.12", ] doc = [ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu index 94a101e9..a63c5716 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -2,12 +2,17 @@ #include #include -#include #include +// Use the resource-ref API (`get_current_device_resource_ref()` + +// value-semantic `allocate_sync`/`deallocate_sync`) rather than the raw-pointer +// `get_current_device_resource()`. RMM 26.06 removes both the raw-pointer +// accessor and `` as it migrates to the cccl +// `cuda::mr` resource-concept model. The ref form compiles unchanged from +// RMM 25.12 through 26.06 (and onward), so it covers 26.04+. void* wilcoxon_rmm_allocate(size_t bytes) { try { - return rmm::mr::get_current_device_resource()->allocate_sync(bytes); + return rmm::mr::get_current_device_resource_ref().allocate_sync(bytes); } catch (std::exception const& e) { throw std::runtime_error( std::string("RMM allocation failed in Wilcoxon scratch (") + @@ -16,5 +21,5 @@ void* wilcoxon_rmm_allocate(size_t bytes) { } void wilcoxon_rmm_deallocate(void* ptr, size_t bytes) { - rmm::mr::get_current_device_resource()->deallocate_sync(ptr, bytes); + rmm::mr::get_current_device_resource_ref().deallocate_sync(ptr, bytes); } From c165d9fb713d87ec0d6a8db6769f8aa41eb3b5ce Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 5 Jun 2026 21:38:49 +0200 Subject: [PATCH 13/36] update ci buildwheel --- .github/workflows/publish.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 088e3b94..efb3f656 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -98,13 +98,13 @@ jobs: # extension against librmm. Add the matching wheel to the isolated # PEP 517 build requirements after selecting the CUDA package variant. for dep in ( - f' "librmm-cu{other}>=25.10",\n', - f' "rmm-cu{other}>=25.10",\n', + f' "librmm-cu{other}>=25.12",\n', + f' "rmm-cu{other}>=25.12",\n', ): text = text.replace(dep, "") - rmm_build_req = f' "librmm-cu{cuda}>=25.10",\n' + rmm_build_req = f' "librmm-cu{cuda}>=25.12",\n' build_system_text = text.split("[project]", 1)[0] - if f'"librmm-cu{cuda}>=25.10"' not in build_system_text: + if f'"librmm-cu{cuda}>=25.12"' not in build_system_text: text = text.replace( ']\nbuild-backend = "scikit_build_core.build"', f'{rmm_build_req}]\nbuild-backend = "scikit_build_core.build"', @@ -164,7 +164,6 @@ jobs: echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" # Exclude CUDA libs by SONAME glob (auditwheel >=6.2): the runtime # stack (CuPy / nvidia-* wheels) provides them. Globs are version # agnostic -- cusolver's SONAME is libcusolver.so.11 on CUDA 12 but @@ -172,7 +171,10 @@ jobs: # major would graft the wrong (or no) lib. cusolver's transitive deps # (cublasLt, cusparse ~186MB, nvJitLink) are reached by auditwheel's # tree walk and must each be excluded or they bloat the wheel. - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' -w {dest_dir} {wheel}" + # librmm.so / librapids_logger.so are also excluded: they are NOT in + # the CuPy/nvidia stack but are provided by the librmm / rapids_logger + # wheels at runtime, so we must not bundle them into our wheel. + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v7 From e8a0ba04d24bb6d1b27c84c609b95f477815f00f Mon Sep 17 00:00:00 2001 From: Intron7 Date: Sat, 6 Jun 2026 11:08:39 +0200 Subject: [PATCH 14/36] add csr densification columnwise clean up rmm --- CMakeLists.txt | 1 + docker/Dockerfile | 15 +- src/rapids_singlecell/_cuda/__init__.py | 1 + .../_cuda/rank_genes/csr_tile_to_dense.cuh | 35 +++++ .../_cuda/rank_genes/rank_stats.cu | 146 ++++++++++++++++++ .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 1 - .../tools/_rank_genes_groups/_core.py | 108 ++++--------- .../tools/_rank_genes_groups/_utils.py | 74 ++++++--- .../tools/_rank_genes_groups/_wilcoxon.py | 26 +++- .../_rank_genes_groups/_wilcoxon_binned.py | 19 ++- tests/test_rank_genes_groups_wilcoxon.py | 42 +++-- .../test_rank_genes_groups_wilcoxon_binned.py | 34 +++- 12 files changed, 370 insertions(+), 132 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh create mode 100644 src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 31a15188..3daaab64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,6 +195,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) + add_nb_cuda_module(_rank_stats_cuda src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/docker/Dockerfile b/docker/Dockerfile index cc533e46..344811a3 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,6 +5,11 @@ ARG GIT_ID=main SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ENV PATH=/opt/conda/bin:$PATH +# Point CMake's find_package(rmm) at the conda env. The conda RAPIDS env resolved +# librmm + cuda-version together, so its librmm/rapids_logger headers match the +# image's CUDA toolkit. This is what lets the --no-build-isolation build below +# pick up the CUDA-matched librmm instead of a mismatched PyPI wheel. +ENV CMAKE_PREFIX_PATH=/opt/conda ARG CUDA_ARCHS="75-real;80-real;86-real;89-real;90-real;100-real;120" RUN < "cudaDevAttr* has no global scope" +# errors on both cu12 (toolkit older than the latest librmm) and cu13 (wrong +# cu12 variant). Install the PEP 517 backend deps first since isolation is off; +# the conda env already provides the librmm/rapids_logger headers + cmake config. +/opt/conda/bin/python -m pip install --no-cache-dir scikit-build-core nanobind setuptools-scm cmake ninja +/opt/conda/bin/python -m pip install --no-cache-dir --no-build-isolation -e . EOF diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index b897c42d..886535df 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -43,6 +43,7 @@ "_pv_cuda", "_qc_cuda", "_qc_dask_cuda", + "_rank_stats_cuda", "_scale_cuda", "_sparse2dense_cuda", "_spca_cuda", diff --git a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh new file mode 100644 index 00000000..1ff8cf11 --- /dev/null +++ b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include + +// CSR-slice + densify in a single pass: scatter the nonzeros of column window +// [col_lb, col_ub) straight into a dense (n_cells, col_ub-col_lb) F-order +// (column-major) double buffer. This skips the CSR -> CSC tile rebuild that a +// `X[:, lb:ub].tocsc()` densify would do. +// +// `out` must be pre-zeroed; the atomicAdd accumulation also makes the result +// correct for uncanonicalized / duplicate column indices (matching scipy's +// sum_duplicates semantics). Output is always double to match the rank_genes +// basic-stats path; the input data dtype is templated. + +template +__global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, + const IndexT* __restrict__ indices, + const TData* __restrict__ data, + double* __restrict__ out, int col_lb, + int col_ub, int n_cells) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_cells) { + return; + } + const long long row_start = static_cast(indptr[row]); + const long long row_end = static_cast(indptr[row + 1]); + for (long long k = row_start; k < row_end; ++k) { + const int col = static_cast(indices[k]); + if (col >= col_lb && col < col_ub) { + atomicAdd( + &out[static_cast(col - col_lb) * n_cells + row], + static_cast(data[k])); + } + } +} diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu new file mode 100644 index 00000000..b403ab1f --- /dev/null +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -0,0 +1,146 @@ +#include + +#include "../nb_types.h" +#include "csr_tile_to_dense.cuh" + +using namespace nb::literals; + +namespace { + +constexpr int GROUP_STATS_BLOCK = 256; + +// Benjamini-Hochberg step-up tail: in-place reverse cumulative minimum along +// each row (group) of an already BH-scaled, p-value-sorted matrix. NaNs are +// treated as 1.0. One block per row, single thread per row (serial scan). +__global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { + const int row = blockIdx.x; + double running = 1.0; + double* row_values = values + static_cast(row) * n_cols; + for (int col = n_cols - 1; col >= 0; --col) { + double value = row_values[col]; + if (!(value == value)) { // NaN -> 1.0 + value = 1.0; + } + if (value < running) { + running = value; + } + row_values[col] = running; + } +} + +// Per-group sum / sum-of-squares / nnz over a dense F-order (column-major) +// block of shape (n_rows x n_cols). group_codes maps each row to a group; rows +// with an out-of-range code are skipped. Outputs are (n_groups x n_cols), +// C-order, accumulated with atomics. Grid-strided so a chunk larger than the +// gridDim.x cap is still fully covered. +__global__ void group_chunk_stats_kernel( + const double* block, const int* group_codes, double* group_sums, + double* group_sum_sq, double* group_nnz, const int n_rows, const int n_cols, + const int n_groups, const bool compute_nnz) { + const long long total = static_cast(n_rows) * n_cols; + const long long stride = static_cast(blockDim.x) * gridDim.x; + for (long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + idx < total; idx += stride) { + const int row = idx % n_rows; + const int col = idx / n_rows; + const int group = group_codes[row]; + if (group < 0 || group >= n_groups) { + continue; + } + const double value = block[idx]; + const long long out = static_cast(group) * n_cols + col; + atomicAdd(group_sums + out, value); + atomicAdd(group_sum_sq + out, value * value); + if (compute_nnz && value != 0.0) { + atomicAdd(group_nnz + out, 1.0); + } + } +} + +} // namespace + +// CSR -> dense F-order (double) window densify, in a single fused pass. +template +static void def_csr_tile_to_dense(nb::module_& m) { + m.def( + "csr_tile_to_dense", + [](gpu_array_c indptr, + gpu_array_c indices, + gpu_array_c data, + gpu_array_f out, int col_lb, int col_ub, + std::uintptr_t stream) { + const int n_cells = static_cast(indptr.shape(0)) - 1; + if (n_cells <= 0 || col_ub <= col_lb) { + return; + } + constexpr int CSR_TILE_BLOCK = 128; + const unsigned int grid = + (static_cast(n_cells) + CSR_TILE_BLOCK - 1) / + CSR_TILE_BLOCK; + csr_tile_to_dense_kernel + <<>>( + indptr.data(), indices.data(), data.data(), out.data(), + col_lb, col_ub, n_cells); + CUDA_CHECK_LAST_ERROR(csr_tile_to_dense_kernel); + }, + "indptr"_a, "indices"_a, "data"_a, "out"_a, nb::kw_only(), "col_lb"_a, + "col_ub"_a, "stream"_a = 0); +} + +template +void register_bindings(nb::module_& m) { + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + + m.def( + "fdr_bh_reverse_cummin", + [](gpu_array_c values, std::uintptr_t stream) { + const int n_rows = static_cast(values.shape(0)); + const int n_cols = static_cast(values.shape(1)); + if (n_rows <= 0 || n_cols <= 0) { + return; + } + fdr_bh_reverse_cummin_kernel<<>>( + values.data(), n_cols); + CUDA_CHECK_LAST_ERROR(fdr_bh_reverse_cummin_kernel); + }, + "values"_a, nb::kw_only(), "stream"_a = 0); + + m.def( + "group_chunk_stats", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c group_sums, + gpu_array_c group_sum_sq, + gpu_array_c group_nnz, bool compute_nnz, + std::uintptr_t stream) { + const int n_rows = static_cast(block.shape(0)); + const int n_cols = static_cast(block.shape(1)); + const int n_groups = static_cast(group_sums.shape(0)); + const long long total = static_cast(n_rows) * n_cols; + if (total <= 0) { + return; + } + const unsigned int grid = strided_grid(total, GROUP_STATS_BLOCK); + group_chunk_stats_kernel<<>>( + block.data(), group_codes.data(), group_sums.data(), + group_sum_sq.data(), group_nnz.data(), n_rows, n_cols, n_groups, + compute_nnz); + CUDA_CHECK_LAST_ERROR(group_chunk_stats_kernel); + }, + "block"_a, "group_codes"_a, "group_sums"_a, "group_sum_sq"_a, + "group_nnz"_a, nb::kw_only(), "compute_nnz"_a, "stream"_a = 0); +} + +NB_MODULE(_rank_stats_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index deb2f395..439eafcf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -423,7 +423,6 @@ static void ovo_streaming_csr_host_impl( checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; } - int max_group_rows = max_pack_rows; size_t max_sub_items = (size_t)max_pack_items; if (max_pack_rows == 0) return; diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index ba144f4f..b093c168 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -13,63 +13,8 @@ from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _check_sparse_nonnegative, _select_groups - -_FDR_BH_REVERSE_CUMMIN_KERNEL = cp.RawKernel( - r""" -extern "C" __global__ void fdr_bh_reverse_cummin(double* values, const int n_cols) { - const int row = blockIdx.x; - double running = 1.0; - double* row_values = values + static_cast(row) * n_cols; - for (int col = n_cols - 1; col >= 0; --col) { - double value = row_values[col]; - if (!(value == value)) { - value = 1.0; - } - if (value < running) { - running = value; - } - row_values[col] = running; - } -} -""", - "fdr_bh_reverse_cummin", -) -_GROUP_CHUNK_STATS_KERNEL = cp.RawKernel( - r""" -extern "C" __global__ void group_chunk_stats( - const double* block, - const int* group_codes, - double* group_sums, - double* group_sum_sq, - double* group_nnz, - const int n_rows, - const int n_cols, - const int n_groups, - const bool compute_nnz -) { - const long long idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - const long long total = static_cast(n_rows) * n_cols; - if (idx >= total) { - return; - } - const int row = idx % n_rows; - const int col = idx / n_rows; - const int group = group_codes[row]; - if (group < 0 || group >= n_groups) { - return; - } - const double value = block[idx]; - const long long out = static_cast(group) * n_cols + col; - atomicAdd(group_sums + out, value); - atomicAdd(group_sum_sq + out, value * value); - if (compute_nnz && value != 0.0) { - atomicAdd(group_nnz + out, 1.0); - } -} -""", - "group_chunk_stats", -) +from ._utils import EPS, _reject_complex, _select_groups, _sparse_has_negative + _RANK_SORT_MIN_ELEMENTS = 1_000_000 _RANK_SORT_MAX_WORKERS = 64 @@ -154,8 +99,6 @@ def __init__( self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] - _check_sparse_nonnegative(self.X) - self.pre_load = pre_load self.ireference = None @@ -181,6 +124,7 @@ def __init__( self.pts_rest: np.ndarray | None = None self.stats_arrays: dict[str, object] | None = None + self._sparse_negative_fallback = False self._store_wilcoxon_gpu_result = False self._wilcoxon_gpu_result: ( tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None @@ -324,23 +268,16 @@ def _accumulate_chunk_stats_vs_rest( group_nnz = ( cp.zeros((n_groups, n_cols), dtype=cp.float64) if self.comp_pts else None ) - n_items = n_cells * n_cols - threads = 256 - blocks = (n_items + threads - 1) // threads - _GROUP_CHUNK_STATS_KERNEL( - (blocks,), - (threads,), - ( - block, - group_codes_dev, - group_sums, - group_sum_sq, - group_nnz if group_nnz is not None else group_sums, - np.int32(n_cells), - np.int32(n_cols), - np.int32(n_groups), - self.comp_pts, - ), + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + + _rs.group_chunk_stats( + block, + group_codes_dev, + group_sums, + group_sum_sq, + group_nnz if group_nnz is not None else group_sums, + compute_nnz=bool(self.comp_pts), + stream=cp.cuda.get_current_stream().ptr, ) # Means @@ -490,6 +427,17 @@ def compute_statistics( **kwds, ) -> None: """Compute statistics for all groups.""" + # The optimized sparse Wilcoxon paths inject implicit zeros analytically + # as a tie at the column minimum (valid only for nonnegative data). + # t-test/logreg are mean/variance/model-based and sign-agnostic. For the + # Wilcoxon methods we reject complex input and, when sparse data holds + # negatives, fall back to the dense full-sort ranking (correct for any + # sign) rather than erroring -- so e.g. signed sparse data still ranks + # correctly, just via the dense path. + self._sparse_negative_fallback = False + if method in {"wilcoxon", "wilcoxon_binned"}: + _reject_complex(self.X) + self._sparse_negative_fallback = _sparse_has_negative(self.X) if self.pre_load or method in { "t-test", "t-test_overestim_var", @@ -623,10 +571,10 @@ def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: corrected_sorted *= corrected_sorted.shape[1] / cp.arange( 1, corrected_sorted.shape[1] + 1, dtype=cp.float64 ) - _FDR_BH_REVERSE_CUMMIN_KERNEL( - (corrected_sorted.shape[0],), - (1,), - (corrected_sorted, np.int32(corrected_sorted.shape[1])), + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + + _rs.fdr_bh_reverse_cummin( + corrected_sorted, stream=cp.cuda.get_current_stream().ptr ) corrected = cp.empty_like(corrected_sorted) cp.put_along_axis(corrected, order, corrected_sorted, axis=1) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index e9efbc50..360928e7 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -19,23 +19,8 @@ MIN_GROUP_SIZE_WARNING = 25 -def _nonnegative_error(prefix: str) -> ValueError: - msg = ( - f"{prefix} contains negative values. rank_genes_groups expects " - "nonnegative expression values; use raw counts or log1p/log-normalized " - "expression, not scaled or centered data." - ) - return ValueError(msg) - - -def _check_sparse_nonnegative(X) -> None: - """Reject inputs with negative values where an eager check is cheap. - - Sparse rank_genes_groups code treats missing entries as true expression - zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add - implicit zeros analytically, which is only valid when explicit sparse - values are nonnegative expression values. - """ +def _reject_complex(X) -> None: + """Reject complex expression values (unsupported by every rank method).""" dtype = None if sp.issparse(X) or cpsp.issparse(X): dtype = np.dtype(X.data.dtype) @@ -45,12 +30,21 @@ def _check_sparse_nonnegative(X) -> None: msg = "rank_genes_groups does not support complex expression values." raise TypeError(msg) - if sp.issparse(X): - if X.nnz > 0 and float(X.data.min()) < 0: - raise _nonnegative_error("Sparse input") - elif cpsp.issparse(X): - if X.nnz > 0 and float(X.data.min()) < 0: - raise _nonnegative_error("Sparse input") + +def _sparse_has_negative(X) -> bool: + """Whether X is a sparse matrix holding an explicit negative value. + + The optimized sparse Wilcoxon paths rank explicit nonzeros and add the + implicit (structural) zeros analytically as a tie at the column minimum, + which is correct only when every stored value is nonnegative (counts / + log1p-normalized data). With a negative stored value the implicit zeros are + no longer the minimum, so that analytic ranking is wrong and the caller + must fall back to the dense full-sort path (valid for any sign). Dense + inputs and the t-test/logreg methods never need this. + """ + if sp.issparse(X) or cpsp.issparse(X): + return X.nnz > 0 and float(X.data.min()) < 0 + return False def _select_groups( @@ -174,9 +168,43 @@ def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray return _sparse_to_dense(csc_chunk, order="F").astype(cp.float64) +def _csr_tile_to_dense_block(X, start: int, stop: int) -> cp.ndarray: + """Densify a CSR column window [start, stop) straight into an F-order + float64 block via a single fused CSR->dense kernel, skipping the CSR->CSC + tile rebuild that ``X[:, start:stop].tocsc()`` (host) / ``X[:, start:stop]`` + (device) would do. For device CSR the index arrays are already on the GPU, + so there is no transfer. + """ + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + + n_rows = X.shape[0] + out = cp.zeros((n_rows, stop - start), dtype=cp.float64, order="F") + if X.nnz == 0: + return out + _rs.csr_tile_to_dense( + cp.asarray(X.indptr), + cp.asarray(X.indices), + cp.asarray(X.data), + out, + col_lb=int(start), + col_ub=int(stop), + stream=cp.cuda.get_current_stream().ptr, + ) + return out + + def _get_column_block(X, start: int, stop: int) -> cp.ndarray: """Extract a column block as a dense F-order float64 CuPy array.""" match X: + # Device CSR: the fused csr_tile_to_dense kernel densifies the window in + # one pass with no transfer (index arrays are already on the GPU) -- the + # big win. Host CSR is intentionally NOT routed here: doing so would + # re-transfer the whole CSR every chunk (only ~1.15x and worse with more + # chunks); host data should be moved to the device once upstream + # (`X_to_GPU`) so it lands in this fast device branch, otherwise it falls + # through to the `.tocsc()` path below. + case cpsp.csr_matrix(): + return _csr_tile_to_dense_block(X, start, stop) case sp.csc_matrix() | sp.csc_array(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case sp.spmatrix() | sp.sparray(): diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 2cef7665..923641ce 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -71,9 +71,17 @@ def _maybe_preload_host_dense(rg: _RankGenes) -> None: def _get_dense_column_block_f32(X, start: int, stop: int) -> cp.ndarray: - """Extract a dense column block as F-order float32 CuPy memory.""" + """Extract a dense column block as F-order float32 CuPy memory. + + For sparse X (the negative-values dense fallback) the column window is + densified on the fly via the shared CSR/CSC densify path, so no full-matrix + dense materialization happens. + """ if isinstance(X, np.ndarray | cp.ndarray): return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + if sp.issparse(X) or cpsp.issparse(X): + block = _get_column_block(X, start, stop) # float64 F-order chunk + return cp.asfortranarray(block.astype(cp.float32, copy=False)) raise TypeError(f"Expected dense matrix, got {type(X)}") @@ -539,7 +547,9 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + host_sparse = ( + isinstance(X, sp.spmatrix | sp.sparray) and not rg._sparse_negative_fallback + ) if host_sparse: if X.format not in {"csr", "csc"}: raise TypeError( @@ -655,7 +665,9 @@ def _wilcoxon_vs_rest( p_host = p_values.get() return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] - if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + if ( + cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) + ) and not rg._sparse_negative_fallback: data, indices, indptr = _device_sparse_arrays_f32(X) group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) @@ -847,7 +859,9 @@ def _wilcoxon_with_reference( ) ) - host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + host_sparse = ( + isinstance(X, sp.spmatrix | sp.sparray) and not rg._sparse_negative_fallback + ) if host_sparse: if X.format not in {"csr", "csc"}: raise TypeError( @@ -996,7 +1010,9 @@ def _wilcoxon_with_reference( for slot, group_index in enumerate(test_group_indices) ] - if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + if ( + cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) + ) and not rg._sparse_negative_fallback: sparse_X = X if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: sparse_X = sparse_X.copy() diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index 14793834..d5f4ed0d 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -11,7 +11,7 @@ from rapids_singlecell._compat import DaskArray from rapids_singlecell._cuda import _wilcoxon_binned_cuda as _wb -from ._utils import MIN_GROUP_SIZE_WARNING +from ._utils import MIN_GROUP_SIZE_WARNING, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray @@ -212,6 +212,7 @@ def wilcoxon_binned( "tie_correct": tie_correct, "use_continuity": use_continuity, "ireference": ireference, + "force_dense": rg._sparse_negative_fallback, } # Pre-allocate output @@ -258,13 +259,27 @@ def process_gene_batch( tie_correct: bool = False, use_continuity: bool = False, ireference: int | None = None, + force_dense: bool = False, ) -> tuple[cp.ndarray, cp.ndarray]: """Process one gene batch, dispatching on Dask vs in-memory.""" n_hist_groups = n_cells_per_group_hist.shape[0] n_genes_batch = stop - start is_sparse = False - if isinstance(X, DaskArray): + if force_dense and cpsp.issparse(X): + # Negative-values fallback: the sparse histogram assigns implicit zeros + # to bin 0, which is correct only for nonnegative data. Densify the + # column window (chunked, no full materialization) and use the dense + # histogram, whose bins span the full [min, max] range. + hist = _launch_dense( + _get_column_block(X, start, stop), + group_codes, + n_hist_groups, + n_bins=n_bins, + bin_low=bin_low, + inv_bin_width=inv_bin_width, + ) + elif isinstance(X, DaskArray): hist = _process_dask( X, start=start, diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index c052aea8..5abcda84 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -33,31 +33,49 @@ def _make_nonnegative(adata): return adata +# The optimized sparse Wilcoxon paths inject implicit zeros analytically as a tie +# at the column minimum, which is valid only for nonnegative data. With negatives +# present they must NOT be used; instead the ranking falls back to the dense +# full-sort path (correct for any sign), so the result matches running the same +# method on the dense matrix. (t-test/t-test_overestim_var/logreg never need this +# and accept signed sparse data directly -- e.g. mixscape's LDA t-test.) @pytest.mark.parametrize( "method", - ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], + ["wilcoxon", "wilcoxon_binned"], ) @pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) -def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): +def test_rank_genes_groups_sparse_negative_values_fallback(method, fmt): X = np.array( [ [-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [2.0, 0.0, 1.0], [0.0, 3.0, 0.0], + [-2.0, 1.0, 0.0], + [1.0, 0.0, 3.0], ], - dtype=np.float32, - ) - adata = sc.AnnData( - X=_to_format(X, fmt), - obs=pd.DataFrame( - {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} - ), - var=pd.DataFrame(index=["g0", "g1", "g2"]), + dtype=np.float64, ) + obs = pd.DataFrame({"group": pd.Categorical(list("aaabbb"), categories=["a", "b"])}) + var = pd.DataFrame(index=["g0", "g1", "g2"]) + + sparse_adata = sc.AnnData(X=_to_format(X, fmt), obs=obs.copy(), var=var.copy()) + dense_fmt = "cupy_dense" if fmt.startswith("cupy") else "numpy_dense" + dense_adata = sc.AnnData(X=_to_format(X, dense_fmt), obs=obs.copy(), var=var.copy()) + + rsc.tl.rank_genes_groups(sparse_adata, "group", method=method, use_raw=False) + rsc.tl.rank_genes_groups(dense_adata, "group", method=method, use_raw=False) - with pytest.raises(ValueError, match="Sparse input contains negative values"): - rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + # Sparse-with-negatives falls back to the dense ranking -> identical result. + sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] + dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] + for group in sp_scores.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_scores[group], dtype=float), + np.asarray(dn_scores[group], dtype=float), + rtol=1e-13, + atol=1e-13, + ) @pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index 85abc3e2..f0e6848d 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -428,21 +428,39 @@ def test_sparse_with_actual_zeros(self, adata_blobs): assert np.all(pvals >= 0) assert np.all(pvals <= 1) - def test_sparse_negative_values_raises(self, adata_blobs): - """Sparse input with negative values should raise ValueError.""" + def test_sparse_negative_values_fallback(self, adata_blobs): + """Sparse input with negatives must not use the sparse histogram (which + assigns implicit zeros to bin 0, valid only for nonnegative data); it + falls back to the dense histogram, so the result matches the dense run. + """ import cupy as cp import cupyx.scipy.sparse as cpsp adata = adata_blobs.copy() rsc.get.anndata_to_GPU(adata) - # Make sparse with negative values - dense = cp.array(adata.X) + dense = cp.asarray(adata.X, dtype=cp.float64) dense[:, 0] = -1.0 - adata.X = cpsp.csr_matrix(dense) - with pytest.raises(ValueError, match="Sparse input contains negative values"): - rsc.tl.rank_genes_groups( - adata, "blobs", method="wilcoxon_binned", use_raw=False + sparse_adata = adata.copy() + sparse_adata.X = cpsp.csr_matrix(dense) + dense_adata = adata.copy() + dense_adata.X = dense + + rsc.tl.rank_genes_groups( + sparse_adata, "blobs", method="wilcoxon_binned", use_raw=False + ) + rsc.tl.rank_genes_groups( + dense_adata, "blobs", method="wilcoxon_binned", use_raw=False + ) + + sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] + dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] + for group in sp_scores.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_scores[group], dtype=float), + np.asarray(dn_scores[group], dtype=float), + rtol=1e-13, + atol=1e-13, ) def test_log1p_warning(self, adata_blobs): From df988d84a4371f8ceb2f0fe8e0f48aff3fbbaaa3 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 8 Jun 2026 23:09:51 +0200 Subject: [PATCH 15/36] fix docker --- .github/workflows/docker.yml | 2 +- conda/rsc_rapids_26.04_cuda12.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 3f24105b..e76070b6 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -73,7 +73,7 @@ jobs: RAPIDS_VER: - "26.04" CUDA_SUFFIX: - - { ver: "12.8.0", label: "cuda12", pkg: "cu12" } + - { ver: "12.9.1", label: "cuda12", pkg: "cu12" } - { ver: "13.0.2", label: "cuda13", pkg: "cu13" } name: Build Docker images (${{ matrix.CUDA_SUFFIX.label }}) runs-on: ubuntu-latest diff --git a/conda/rsc_rapids_26.04_cuda12.yml b/conda/rsc_rapids_26.04_cuda12.yml index 537b365a..f0010a8b 100644 --- a/conda/rsc_rapids_26.04_cuda12.yml +++ b/conda/rsc_rapids_26.04_cuda12.yml @@ -7,7 +7,7 @@ channels: dependencies: - rapids=26.04 - python=3.14 - - cuda-version=12.8 + - cuda-version=12.9 - cudnn - cutensor - cusparselt From 82703d22f3e864fa935117d592eafd88bded0973 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 9 Jun 2026 01:31:05 +0200 Subject: [PATCH 16/36] fix issues --- CMakeLists.txt | 28 ++++++++++++ docs/contributing.md | 18 ++++++++ docs/release-notes/{0.15.3.md => 0.16.0.md} | 3 +- docs/release-notes/index.md | 6 ++- hatch.toml | 8 ++++ pyproject.toml | 9 +++- src/rapids_singlecell/_cuda/nb_types.h | 13 ++++++ .../_cuda/rank_genes/csr_tile_to_dense.cuh | 13 +++--- .../_cuda/rank_genes/rank_stats.cu | 42 ++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon.cu | 18 ++++---- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 14 +++--- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 22 ++++++---- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 20 +++++---- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 9 ++-- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 6 +++ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 43 +++++++++++-------- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 6 +++ .../tools/_rank_genes_groups/_wilcoxon.py | 27 ++++++++++-- 18 files changed, 238 insertions(+), 67 deletions(-) rename docs/release-notes/{0.15.3.md => 0.16.0.md} (78%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3daaab64..1088dbe6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,34 @@ if (RSC_BUILD_EXTENSIONS) else() find_package(rmm CONFIG REQUIRED) endif() + + # CCCL 3.3.0 (shipped by RAPIDS 26.04) declares the + # cudaDevAttrHostNumaMemoryPoolsSupported device-attribute specialization under + # `_CCCL_CTK_AT_LEAST(12, 6)`, but the CUDA runtime only added that enum in 12.9. + # So compiling the RMM/CCCL-including TUs (the Wilcoxon scratch allocator) against + # a CUDA 12.6-12.8 toolkit fails with a cryptic + # `error: the global scope has no "cudaDevAttrHostNumaMemoryPoolsSupported"`. + # CCCL fixed the guard to `_CCCL_CTK_AT_LEAST(12, 9)` after 3.3.0 (cccl PR #7838), + # so RAPIDS >= 26.06 (CCCL > 3.3.0) closes the gap -- only flag the buggy CCCL. + # Fail fast with an actionable message. Prebuilt wheels are unaffected: they are + # built on CUDA 12.2 (below the guard), so the enum is never referenced. + set(_rsc_cccl_buggy_numa_guard TRUE) + if (DEFINED CCCL_VERSION AND CCCL_VERSION VERSION_GREATER 3.3.0) + set(_rsc_cccl_buggy_numa_guard FALSE) + endif() + if (NOT RSC_SKIP_CUDA_VERSION_CHECK + AND _rsc_cccl_buggy_numa_guard + AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.6 + AND CUDAToolkit_VERSION VERSION_LESS 12.9) + message(FATAL_ERROR + "Cannot build rapids_singlecell from source with CUDA ${CUDAToolkit_VERSION} against " + "CCCL ${CCCL_VERSION} (RAPIDS 26.04): it references cudaDevAttrHostNumaMemoryPoolsSupported, " + "which the CUDA 12.6-12.8 toolkit does not define (NVIDIA added it in 12.9). " + "Use CUDA >= 12.9 (or <= 12.5), upgrade to RAPIDS >= 26.06 (CCCL > 3.3.0 fixes the guard), " + "or install the prebuilt wheel (pip install rapids-singlecell-cu12). " + "If your toolkit does define this enum, override with -DRSC_SKIP_CUDA_VERSION_CHECK=ON.") + endif() + message(STATUS "Using RMM for CUDA extension scratch allocations") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() diff --git a/docs/contributing.md b/docs/contributing.md index f0542e57..e68011d2 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -7,6 +7,24 @@ - NVIDIA GPU with CUDA support - [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html), conda/mamba, or [uv](https://docs.astral.sh/uv/) - A RAPIDS environment (e.g., conda `rapids-26.04` or pip-installed RAPIDS) +- **CUDA toolkit ≥ 12.9, or ≤ 12.5, for building from source** (see note below) + +```{important} +**On RAPIDS 26.04, building from source needs CUDA ≥ 12.9 (or ≤ 12.5) on CUDA 12.** +RAPIDS 26.04 ships CCCL 3.3.0, which references the `cudaDevAttrHostNumaMemoryPoolsSupported` +device attribute whenever the toolkit is ≥ 12.6, but NVIDIA only added that enum in +CUDA 12.9. So compiling the RMM/CCCL-using kernels (the Wilcoxon scratch allocator) +against a **CUDA 12.6–12.8** toolkit fails with +`error: the global scope has no "cudaDevAttrHostNumaMemoryPoolsSupported"`. + +This is an upstream CCCL guard bug, **fixed in CCCL > 3.3.0 (RAPIDS ≥ 26.06)** — so +the gap only affects RAPIDS 26.04. CUDA 13.x is unaffected. If you're on RAPIDS 26.04 ++ CUDA 12.6–12.8, either build with CUDA ≥ 12.9 (or ≤ 12.5), upgrade to RAPIDS ≥ 26.06, +or just use the **prebuilt wheel** (`pip install rapids-singlecell-cu12`) — wheels are +built on CUDA 12.2 (below the guard), so the enum is never referenced and they run fine +on any CUDA 12.x runtime, including 12.6–12.8. The build emits an actionable error in +this range; override only if your toolkit defines the enum with `-DRSC_SKIP_CUDA_VERSION_CHECK=ON`. +``` ### Clone and install diff --git a/docs/release-notes/0.15.3.md b/docs/release-notes/0.16.0.md similarity index 78% rename from docs/release-notes/0.15.3.md rename to docs/release-notes/0.16.0.md index 6b3a3f8c..1109d619 100644 --- a/docs/release-notes/0.15.3.md +++ b/docs/release-notes/0.16.0.md @@ -1,7 +1,8 @@ -### 0.15.3 {small}`the-future` +### 0.16.0 {small}`the-future` ```{rubric} Features ``` +* Reworked GPU {func}`~rapids_singlecell.tl.rank_genes_groups` Wilcoxon onto dedicated nanobind CUDA kernels {pr}`636` {smaller}`S Dicks` * Add {class}`~rapids_singlecell.ptg.Mixscape` for GPU-accelerated Mixscape (`perturbation_signature`, `mixscape`, `mixscale`, `lda`) {pr}`688` {smaller}`S Dicks` ```{rubric} Performance diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 329eb0ed..1f01cc8a 100644 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -3,9 +3,11 @@ # Release notes -## Version 0.15.0 -```{include} /release-notes/0.15.3.md +## Version 0.16.0 +```{include} /release-notes/0.16.0.md ``` + +## Version 0.15.0 ```{include} /release-notes/0.15.2.md ``` ```{include} /release-notes/0.15.1.md diff --git a/hatch.toml b/hatch.toml index d7376c2a..81112df2 100644 --- a/hatch.toml +++ b/hatch.toml @@ -38,6 +38,14 @@ overrides.matrix.deps.extra-dependencies = [ { if = [ "dev", ], value = "scanpy @ git+https://github.com/scverse/scanpy.git" }, + # numpy 2.5 dropped the `np.row_stack` alias that numba-cuda still calls, so + # `import rapids_singlecell` fails on numpy>=2.5. Only the prerelease-allowing + # envs can pull it; cap below 2.5 (`<2.5.0a0` so it also excludes 2.5.0rcN + # under UV_PRERELEASE=allow). Drop once numba-cuda no longer needs the alias. + { if = [ + "dev", + "rapids_prerelease", + ], value = "numpy<2.5.0a0" }, ] overrides.matrix.cuda.extra-dependencies = [ { if = [ "13" ], value = "cuml-cu13<26.8" }, diff --git a/pyproject.toml b/pyproject.toml index 3736e38d..4117e77d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,8 +172,15 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] +# Exclude CUDA libs by SONAME glob (auditwheel >= 6.2): suffixes are version- +# dependent (cusolver is libcusolver.so.11 on CUDA 12 but .12 on CUDA 13), so a +# `.*` glob stays correct across CUDA majors where hardcoded `.12`/`.13` would +# miss variants and bundle ~186MB (cublasLt, cusparse, nvJitLink) that CuPy +# provides at runtime. librmm.so / librapids_logger.so come from the librmm / +# rapids_logger wheels. Keep this list in sync with CIBW_REPAIR_WHEEL_COMMAND in +# .github/workflows/publish.yml (that env var overrides this block in CI). repair-wheel-command = [ - "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", + "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", ] [tool.cibuildwheel.macos] diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index f4daa926..cd4fe4d6 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -21,6 +21,19 @@ inline void cuda_check_last_error(const char* kernel_name) { #define CUDA_CHECK_LAST_ERROR(kernel_name) cuda_check_last_error(#kernel_name) +/// Check a cudaError_t returned directly by a CUDA/CUB API call. +/// Unlike CUDA_CHECK_LAST_ERROR (which inspects cudaGetLastError after a +/// <<<...>>> launch), this validates the status a function call returns -- e.g. +/// cub::DeviceSegmentedRadixSort::SortKeys or cudaStreamSynchronize -- so a +/// failed call surfaces here with a clear label instead of as corrupted output +/// at a later synchronization point. +inline void cuda_check(cudaError_t err, const char* what) { + if (err != cudaSuccess) { + throw std::runtime_error(std::string(what) + + " failed: " + cudaGetErrorString(err)); + } +} + /// Per-axis cached cap on `gridDim.{x,y,z}`. These differ in CUDA: /// gridDim.x: 2^31-1 on CC 3.0+ /// gridDim.y: 65535 on most GPUs diff --git a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh index 1ff8cf11..e39e32e5 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh +++ b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh @@ -24,12 +24,15 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, } const long long row_start = static_cast(indptr[row]); const long long row_end = static_cast(indptr[row + 1]); + // Keep column ids in IndexT: narrowing a 64-bit IndexT to int would + // truncate large column ids and misplace writes. + const IndexT lb = static_cast(col_lb); + const IndexT ub = static_cast(col_ub); for (long long k = row_start; k < row_end; ++k) { - const int col = static_cast(indices[k]); - if (col >= col_lb && col < col_ub) { - atomicAdd( - &out[static_cast(col - col_lb) * n_cells + row], - static_cast(data[k])); + const IndexT col = indices[k]; + if (col >= lb && col < ub) { + atomicAdd(&out[static_cast(col - lb) * n_cells + row], + static_cast(data[k])); } } } diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu index b403ab1f..9893af17 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -74,6 +74,22 @@ static void def_csr_tile_to_dense(nb::module_& m) { if (n_cells <= 0 || col_ub <= col_lb) { return; } + if (col_lb < 0) { + throw std::invalid_argument( + "csr_tile_to_dense: col_lb must be non-negative"); + } + if (indices.shape(0) != data.shape(0)) { + throw std::invalid_argument( + "csr_tile_to_dense: indices and data must have equal " + "length"); + } + if (out.ndim() != 2 || static_cast(out.shape(0)) != n_cells || + static_cast(out.shape(1)) < + static_cast(col_ub) - col_lb) { + throw std::invalid_argument( + "csr_tile_to_dense: out must be a (n_cells, >= col_ub - " + "col_lb) array"); + } constexpr int CSR_TILE_BLOCK = 128; const unsigned int grid = (static_cast(n_cells) + CSR_TILE_BLOCK - 1) / @@ -122,6 +138,12 @@ void register_bindings(nb::module_& m) { gpu_array_c group_sum_sq, gpu_array_c group_nnz, bool compute_nnz, std::uintptr_t stream) { + if (block.ndim() != 2 || group_sums.ndim() != 2 || + group_sum_sq.ndim() != 2) { + throw std::invalid_argument( + "group_chunk_stats: block, group_sums and group_sum_sq " + "must be 2-D"); + } const int n_rows = static_cast(block.shape(0)); const int n_cols = static_cast(block.shape(1)); const int n_groups = static_cast(group_sums.shape(0)); @@ -129,6 +151,26 @@ void register_bindings(nb::module_& m) { if (total <= 0) { return; } + if (static_cast(group_codes.shape(0)) != n_rows) { + throw std::invalid_argument( + "group_chunk_stats: group_codes length must equal block " + "rows"); + } + if (static_cast(group_sum_sq.shape(0)) != n_groups || + static_cast(group_sums.shape(1)) != n_cols || + static_cast(group_sum_sq.shape(1)) != n_cols) { + throw std::invalid_argument( + "group_chunk_stats: group_sums and group_sum_sq must be " + "(n_groups, n_cols)"); + } + if (compute_nnz && + (group_nnz.ndim() != 2 || + static_cast(group_nnz.shape(0)) != n_groups || + static_cast(group_nnz.shape(1)) != n_cols)) { + throw std::invalid_argument( + "group_chunk_stats: group_nnz must be (n_groups, n_cols) " + "when compute_nnz is set"); + } const unsigned int grid = strided_grid(total, GROUP_STATS_BLOCK); group_chunk_stats_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 890242c2..4542d634 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -88,10 +88,11 @@ static void launch_ovr_rank_dense_streaming( const float* keys_in = block + (size_t)col * n_rows; size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "dense OVR segmented sort"); if (use_gmem) { cudaMemsetAsync(buf.sub_rank_sums, 0, @@ -260,10 +261,11 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( const float* grp_sub = grp_data + (size_t)col * n_all_grp; upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); size_t ref_temp = ref_cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.ref_cub_temp, ref_temp, ref_sub, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + buf.ref_cub_temp, ref_temp, ref_sub, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "dense OVO ref segmented sort"); ref_sub = buf.ref_sorted; OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 11fa4dbc..8a5303db 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -58,9 +58,10 @@ static inline size_t cub_segmented_sortkeys_temp_bytes(int num_items, size_t bytes = 0; auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, fk, fk, num_items, - num_segments, doff, doff + 1, - BEGIN_BIT, END_BIT); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, bytes, fk, fk, num_items, num_segments, doff, + doff + 1, BEGIN_BIT, END_BIT), + "CUB SortKeys temp-size query"); return bytes; } @@ -71,9 +72,10 @@ static inline size_t cub_segmented_sortpairs_temp_bytes(int num_items, auto* fk = reinterpret_cast(1); auto* v = reinterpret_cast(1); auto* off = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs(nullptr, bytes, fk, fk, v, v, - num_items, num_segments, off, - off + 1, BEGIN_BIT, END_BIT); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, bytes, fk, fk, v, v, num_items, num_segments, off, + off + 1, BEGIN_BIT, END_BIT), + "CUB SortPairs temp-size query"); return bytes; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 91622f71..2e3696fd 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -155,11 +155,14 @@ static void ovo_streaming_csr_impl( cub_segmented_sortkeys_temp_bytes(cache_ref_items_i32, cache_cols); ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); size_t ref_temp = ref_cub_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, - cache_ref_items_i32, cache_cols, d_ref_seg_offsets, - d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream); - cudaStreamSynchronize(ref_stream); + cuda_check( + cub::DeviceSegmentedRadixSort::SortKeys( + ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, + cache_ref_items_i32, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream), + "device CSR OVO ref segmented sort"); + cuda_check(cudaStreamSynchronize(ref_stream), + "device CSR OVO ref sort sync"); int col = cache_col; int cache_stop = cache_col + cache_cols; @@ -354,10 +357,11 @@ static void ovo_streaming_csc_impl( upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "device CSC OVO ref segmented sort"); } cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 439eafcf..8ffe1208 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -223,10 +223,11 @@ static void ovo_streaming_csc_host_impl( upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "host CSC OVO ref segmented sort"); } // ---- Extract grp from CSC via row_map ---- @@ -559,10 +560,13 @@ static void ovo_streaming_csr_host_impl( ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); size_t temp = ref_cub_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, ref_items_i32, - n_cols, d_ref_seg, d_ref_seg + 1, BEGIN_BIT, END_BIT, ref_stream); - cudaStreamSynchronize(ref_stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, + ref_items_i32, n_cols, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT, ref_stream), + "host CSR OVO ref segmented sort"); + cuda_check(cudaStreamSynchronize(ref_stream), + "host CSR OVO ref sort sync"); } // ref scratch drops here cudaStreamDestroy(ref_stream); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index 9f823aba..f162782a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -266,10 +266,11 @@ static inline void ovo_dispatch_tiers( CUDA_CHECK_LAST_ERROR(build_huge_seg_offsets_kernel); size_t temp = grp_cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - sc.grp_cub_temp, temp, grp_dense, sc.grp_sorted, - sb_grp_items_actual, sb_grp_seg, sc.grp_seg_offsets, - sc.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + sc.grp_cub_temp, temp, grp_dense, sc.grp_sorted, + sb_grp_items_actual, sb_grp_seg, sc.grp_seg_offsets, + sc.grp_seg_ends, BEGIN_BIT, END_BIT, stream), + "OVO huge-tier group segmented sort"); dim3 grid(sb_cols, n_groups); ovo_rank_huge_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 2b282b0b..67bc1dfe 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -20,6 +20,12 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). * write_pos[c - col_start] must be initialized to the prefix-sum offset * for column c. Each thread atomically claims a unique destination slot. + * + * PRECONDITION: each row's `indices` must be sorted ascending. The binary + * search for col_start and the `break` at col_stop both depend on it; unsorted + * rows would silently drop or misplace nonzeros. Every caller enforces this -- + * the Python dispatch calls `sort_indices()` on the CSR/CSC input before + * invoking the streaming impls that launch this kernel. */ template __global__ void csr_scatter_to_csc_kernel( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index e683fadf..9690e5ce 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -174,11 +174,12 @@ static void ovr_sparse_csc_host_streaming_impl( // CUB sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, - buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, - buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, - stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, + buf.keys_out, buf.d_sparse_indices, buf.vals_out, + batch_nnz, sb_cols, buf.d_seg_offsets, + buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "host CSC OVR segmented sort"); } // Sparse rank kernel (stats already captured above) @@ -448,11 +449,12 @@ static void ovr_sparse_csr_host_streaming_impl( if (batch_nnz > 0) { size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, - buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, - buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, - stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, + END_BIT, stream), + "host CSR OVR segmented sort"); } launch_ovr_sparse_rank( @@ -599,11 +601,12 @@ static void ovr_sparse_csc_streaming_impl( // Sort only stored values (keys=data, vals=row_indices) if (batch_nnz > 0) { size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, - csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, - buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, - stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, + buf.keys_out, csc_indices + ptr_start, buf.vals_out, + batch_nnz, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream), + "device CSC OVR segmented sort"); } // Sparse rank kernel (handles implicit zeros analytically) @@ -803,10 +806,12 @@ static void ovr_sparse_csr_streaming_impl( // CUB sort only the nonzeros size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, - buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, - buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, + END_BIT, stream), + "device CSR OVR segmented sort"); } // Sparse rank kernel (handles implicit zeros analytically) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 5cf7a067..4ac0b62b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -30,6 +30,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c rank_sums, \ gpu_array_c tie_corr, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ IMPL(data.data(), indices.data(), indptr.data(), \ group_codes.data(), group_sizes.data(), rank_sums.data(), \ tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, \ @@ -64,6 +65,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_group_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ bool compute_nnz, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csc_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ @@ -102,6 +104,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_group_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ bool compute_nnz, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csr_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ @@ -139,6 +142,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c tie_corr, int n_ref, int n_all_grp, \ int n_cols, int n_groups, bool compute_tie_corr, \ int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ IMPL(data.data(), indices.data(), indptr.data(), ref_rows.data(), \ grp_rows.data(), grp_offsets.data(), rank_sums.data(), \ tie_corr.data(), n_ref, n_all_grp, n_cols, n_groups, \ @@ -176,6 +180,7 @@ void register_sparse_bindings(nb::module_& m) { int n_rows, int n_cols, int n_groups, int n_groups_stats, \ bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovo_streaming_csc_host_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_ref_row_map.data(), h_grp_row_map.data(), \ @@ -216,6 +221,7 @@ void register_sparse_bindings(nb::module_& m) { int n_ref, int n_all_grp, int n_cols, int n_test, \ int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ bool compute_nnz, bool compute_sums, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovo_streaming_csr_host_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 923641ce..45088c81 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -354,6 +354,10 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X): # Row/column indices always fit int32 (cells and genes are < 2^31); only the # indptr (cumulative nnz) can need int64. Mirrors the rest of the sparse code. is_i64 = X.indptr.dtype == np.int64 + # The *_f64 binding only changes the host pointer dtype so float64 data can + # be passed without a host-side copy; it still ranks in float32 on-device + # (the kernels cast InT -> float before the segmented sort). See + # _device_sparse_arrays_f32 for why float32 ranking is the uniform design. suffix = "" if is_f64: suffix += "_f64" @@ -365,6 +369,15 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X): def _device_sparse_arrays_f32(X): + """Cast device-sparse arrays for the Wilcoxon kernels. + + Wilcoxon ranking sorts float32 keys on every path -- the sparse fast paths + AND the dense fallback (``_get_dense_column_block_f32``); the CUB segmented + sort is float-keyed throughout. Casting ``X.data`` to float32 here therefore + does not diverge from any float64 ranking path, because there is none. For + count data float32 is exact (integer values < 2**24) and scanpy parity holds + at 1e-13. float64 input is accepted only to spare the caller a pre-cast. + """ data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float32 or data_dtype == np.float64: pass @@ -869,7 +882,9 @@ def _wilcoxon_with_reference( f"full-matrix conversion from {X.format!r}." ) - rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + # zeros, not empty: an all-empty test batch (n_all_grp == 0) + # short-circuits the kernel without writing rank_sums. + rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) n_groups_stats = n_test + 1 compute_vars = False @@ -1019,7 +1034,9 @@ def _wilcoxon_with_reference( sparse_X.sort_indices() data, indices, indptr = _device_sparse_arrays_f32(sparse_X) offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) - rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + # zeros, not empty: an all-empty test batch (n_all_grp == 0) + # short-circuits the kernel without writing rank_sums. + rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) if cpsp.isspmatrix_csc(sparse_X): @@ -1110,8 +1127,10 @@ def _wilcoxon_with_reference( ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") - rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) - tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) + # zeros/ones, not empty: an all-empty test batch (n_all_grp == 0) + # short-circuits the kernel, leaving these outputs unwritten. + rank_sums = cp.zeros((n_test, n_cols), dtype=cp.float64) + tie_corr = cp.ones((n_test, n_cols), dtype=cp.float64) _wc.ovo_rank_dense_tiered_unsorted_ref( ref_f32, From 00eccb28dd9c8c6e9b43a955f054d6b3cac0e328 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 9 Jun 2026 03:27:58 +0200 Subject: [PATCH 17/36] redo memory allocation --- CMakeLists.txt | 22 ++++-- src/rapids_singlecell/_cuda/__init__.py | 1 + .../wilcoxon_rmm.cu => rmm_scratch.cu} | 8 ++- src/rapids_singlecell/_cuda/rmm_scratch.h | 68 +++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 60 +--------------- 5 files changed, 92 insertions(+), 67 deletions(-) rename src/rapids_singlecell/_cuda/{wilcoxon/wilcoxon_rmm.cu => rmm_scratch.cu} (83%) create mode 100644 src/rapids_singlecell/_cuda/rmm_scratch.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1088dbe6..c98171b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,6 +188,20 @@ function(add_nb_cuda_module target src) endif() endfunction() +# An RMM-backed nanobind CUDA module: add_nb_cuda_module plus the shared RMM +# scratch allocator (rmm_scratch.cu) and the rmm::rmm link. librmm.so is resolved +# at runtime via the cuML preload (rapids_singlecell/__init__.py imports cuML +# before these extensions, loading librmm into the process), so no INSTALL_RPATH +# is needed. Reusable by any module that needs RMM device scratch. +function(add_rmm_cuda_module target src) + add_nb_cuda_module(${target} ${src}) + if (RSC_BUILD_EXTENSIONS) + target_sources(${target} PRIVATE + src/rapids_singlecell/_cuda/rmm_scratch.cu) + target_link_libraries(${target} PRIVATE rmm::rmm) + endif() +endfunction() + if (RSC_BUILD_EXTENSIONS) # CUDA modules add_nb_cuda_module(_mean_var_cuda src/rapids_singlecell/_cuda/mean_var/mean_var.cu) @@ -217,12 +231,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_pseudobulk_cuda src/rapids_singlecell/_cuda/pseudobulk/pseudobulk.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) - add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) - target_sources(_wilcoxon_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) - target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm) - add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) - target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) - target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) + add_rmm_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_rmm_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) add_nb_cuda_module(_rank_stats_cuda src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 886535df..35a6ab06 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -49,6 +49,7 @@ "_spca_cuda", "_wilcoxon_binned_cuda", "_wilcoxon_cuda", + "_wilcoxon_sparse_cuda", ] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/rmm_scratch.cu similarity index 83% rename from src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu rename to src/rapids_singlecell/_cuda/rmm_scratch.cu index a63c5716..efef484c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu +++ b/src/rapids_singlecell/_cuda/rmm_scratch.cu @@ -4,22 +4,24 @@ #include +#include "rmm_scratch.h" + // Use the resource-ref API (`get_current_device_resource_ref()` + // value-semantic `allocate_sync`/`deallocate_sync`) rather than the raw-pointer // `get_current_device_resource()`. RMM 26.06 removes both the raw-pointer // accessor and `` as it migrates to the cccl // `cuda::mr` resource-concept model. The ref form compiles unchanged from // RMM 25.12 through 26.06 (and onward), so it covers 26.04+. -void* wilcoxon_rmm_allocate(size_t bytes) { +void* rmm_allocate(size_t bytes) { try { return rmm::mr::get_current_device_resource_ref().allocate_sync(bytes); } catch (std::exception const& e) { throw std::runtime_error( - std::string("RMM allocation failed in Wilcoxon scratch (") + + std::string("RMM scratch allocation failed (") + std::to_string(bytes) + " bytes): " + e.what()); } } -void wilcoxon_rmm_deallocate(void* ptr, size_t bytes) { +void rmm_deallocate(void* ptr, size_t bytes) { rmm::mr::get_current_device_resource_ref().deallocate_sync(ptr, bytes); } diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.h b/src/rapids_singlecell/_cuda/rmm_scratch.h new file mode 100644 index 00000000..bf746dfc --- /dev/null +++ b/src/rapids_singlecell/_cuda/rmm_scratch.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Shared RMM-backed device scratch, usable by any CUDA module that links +// rmm::rmm (see add_rmm_cuda_module in CMakeLists.txt). Allocations come from +// the current RMM device resource, so scratch participates in the same pool as +// CuPy/RAPIDS allocations. +void* rmm_allocate(size_t bytes); +void rmm_deallocate(void* ptr, size_t bytes); + +// --------------------------------------------------------------------------- +// Small allocation pool for temporary CUDA buffers. Frees everything on scope +// exit; reuse a single pool across a kernel pipeline. +// --------------------------------------------------------------------------- +struct RmmScratchPool { + struct Allocation { + void* ptr = nullptr; + size_t bytes = 0; + }; + std::vector bufs; + + ~RmmScratchPool() { + for (Allocation alloc : bufs) { + if (!alloc.ptr) continue; + rmm_deallocate(alloc.ptr, alloc.bytes); + } + } + + template + T* alloc(size_t count) { + if (count == 0) count = 1; + if (count > std::numeric_limits::max() / sizeof(T)) { + throw std::runtime_error("RMM scratch allocation size overflow"); + } + size_t bytes = count * sizeof(T); + void* ptr = rmm_allocate(bytes); + bufs.push_back({ptr, bytes}); + return static_cast(ptr); + } +}; + +// Single RAII RMM device buffer (frees on scope exit). +struct ScopedCudaBuffer { + void* ptr = nullptr; + size_t bytes = 0; + + explicit ScopedCudaBuffer(size_t requested_bytes) { + bytes = requested_bytes == 0 ? 1 : requested_bytes; + ptr = rmm_allocate(bytes); + } + + ~ScopedCudaBuffer() { + if (!ptr) return; + rmm_deallocate(ptr, bytes); + } + + void* data() { + return ptr; + } + + ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; + ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; +}; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 8a5303db..b0ea1e0e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -10,10 +10,8 @@ #include -#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR - -void* wilcoxon_rmm_allocate(size_t bytes); -void wilcoxon_rmm_deallocate(void* ptr, size_t bytes); +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR +#include "../rmm_scratch.h" // rmm_allocate, RmmScratchPool, ScopedCudaBuffer constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; @@ -180,60 +178,6 @@ struct HostRegisterGuard { } }; -// --------------------------------------------------------------------------- -// Small allocation pool for temporary CUDA buffers. Uses the current RMM device -// resource so scratch participates in the same pool as CuPy/RAPIDS allocations. -// --------------------------------------------------------------------------- -struct RmmScratchPool { - struct Allocation { - void* ptr = nullptr; - size_t bytes = 0; - }; - std::vector bufs; - - ~RmmScratchPool() { - for (Allocation alloc : bufs) { - if (!alloc.ptr) continue; - wilcoxon_rmm_deallocate(alloc.ptr, alloc.bytes); - } - } - - template - T* alloc(size_t count) { - if (count == 0) count = 1; - if (count > std::numeric_limits::max() / sizeof(T)) { - throw std::runtime_error( - "Wilcoxon scratch allocation size overflow"); - } - size_t bytes = count * sizeof(T); - void* ptr = wilcoxon_rmm_allocate(bytes); - bufs.push_back({ptr, bytes}); - return static_cast(ptr); - } -}; - -struct ScopedCudaBuffer { - void* ptr = nullptr; - size_t bytes = 0; - - explicit ScopedCudaBuffer(size_t requested_bytes) { - bytes = requested_bytes == 0 ? 1 : requested_bytes; - ptr = wilcoxon_rmm_allocate(bytes); - } - - ~ScopedCudaBuffer() { - if (!ptr) return; - wilcoxon_rmm_deallocate(ptr, bytes); - } - - void* data() { - return ptr; - } - - ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; - ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; -}; - static inline int round_up_to_warp(int n) { int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; From 78365784e24cb9f246d3efdd8ad745a8a8d0bdd0 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 16 Jun 2026 17:52:50 +0200 Subject: [PATCH 18/36] clean up --- docs/installation.md | 15 + src/rapids_singlecell/_cuda/__init__.py | 36 +- .../_cuda/rank_genes/csr_tile_to_dense.cuh | 10 +- .../_cuda/wilcoxon/wilcoxon.cu | 30 +- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 92 ++++++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 12 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 13 +- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 2 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 20 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 9 +- .../tools/_rank_genes_groups/__init__.py | 15 +- .../tools/_rank_genes_groups/_core.py | 44 --- .../tools/_rank_genes_groups/_utils.py | 7 - .../tools/_rank_genes_groups/_wilcoxon.py | 2 +- .../_rank_genes_groups/_wilcoxon_binned.py | 13 + tests/test_rank_genes_groups_wilcoxon.py | 311 +++++++++++++----- 16 files changed, 427 insertions(+), 204 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 9dd5deb3..35fecdc8 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -65,6 +65,14 @@ pip install rapids-singlecell-cu12 This installs the precompiled CUDA kernels but **not** the RAPIDS stack (cupy, cuml, cudf, etc.). This is the recommended approach for **conda/mamba users** who already have RAPIDS installed in their environment. +```{note} +The compiled kernels (Wilcoxon, GMM, …) link `librmm` / `rapids_logger` at +runtime. These are **required**: they are provided by an existing RAPIDS +conda/mamba environment or by the `[rapids]`/`[rapids-cuXX]` extra below. +Installing the bare `rapids-singlecell-cuXX` wheel into an environment without +RAPIDS raises an `ImportError` when those kernels are first used. +``` + ### Prebuilt wheels with RAPIDS dependencies To also install the RAPIDS stack via pip, use the `rapids` extra. @@ -102,6 +110,13 @@ pip install 'rapids-singlecell[rapids-cu12]' --extra-index-url=https://pypi.nvid ```{note} Building from source requires the CUDA toolkit (nvcc) and CMake >= 3.24 to be available in your environment. The nvcc/CUDAToolkit found during the build should match the RAPIDS/CuPy CUDA major runtime version in or linked to the environment. + +Isolated source builds (the default for `pip install rapids-singlecell` and the +`git+` installs below) pull `librmm-cu12` into the build environment regardless +of your local CUDA major. On a **CUDA 13** system this mismatches the toolkit, so +build inside an environment that already provides a matching `librmm` and pass +`--no-build-isolation` (e.g. `pip install --no-build-isolation "rapids-singlecell @ git+…"`) +so the build uses the environment's `librmm` instead of the cu12 wheel. ``` ### Install from GitHub diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 35a6ab06..625a145f 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -5,8 +5,10 @@ operations. Each module is compiled from CUDA source files and exposed through nanobind bindings. -On systems without compiled extensions (e.g., docs builds), imports resolve -to None so that module-level imports don't raise ImportError. +On systems without compiled extensions (e.g., docs builds), a genuinely absent +module resolves to None so that module-level imports don't raise ImportError. A +module that is present but fails to load (ABI/toolkit mismatch, missing shared +library) is re-raised with context rather than silently swallowed. """ from __future__ import annotations @@ -53,10 +55,38 @@ ] +def _preload_rapids_runtime_libs() -> None: + """Pre-load ``librmm`` / ``rapids_logger`` so the extensions' ``DT_NEEDED`` + soname deps resolve regardless of import order (the editable-install + ``RUNPATH`` is unreliable). Best-effort: absent wheels (docs builds) skip. + """ + for mod in ("librmm", "rapids_logger"): + try: + importlib.import_module(mod).load_library() + except (ImportError, OSError, AttributeError, RuntimeError): + pass + + +_preload_rapids_runtime_libs() + + def __getattr__(name: str): if name in __all__: try: return importlib.import_module(f".{name}", __name__) - except ImportError: + except ModuleNotFoundError: + # Extension genuinely absent (e.g. docs builds, no-GPU installs): + # degrade to None so module-level imports don't raise. return None + except ImportError as exc: + # Extension present but failed to load (ABI/toolkit mismatch, a + # missing shared library, the rmm symbol-ordering issue, ...). + # Surface it with context instead of silently returning None and + # crashing later with a cryptic ``'NoneType' has no attribute ...``. + msg = ( + f"Failed to load compiled CUDA extension {name!r}: {exc}. " + "Ensure a matching rapids-singlecell-cuXX wheel (and librmm) is " + "installed for your CUDA version." + ) + raise ImportError(msg) from exc raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh index e39e32e5..f80ada7b 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh +++ b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh @@ -7,10 +7,9 @@ // (column-major) double buffer. This skips the CSR -> CSC tile rebuild that a // `X[:, lb:ub].tocsc()` densify would do. // -// `out` must be pre-zeroed; the atomicAdd accumulation also makes the result -// correct for uncanonicalized / duplicate column indices (matching scipy's -// sum_duplicates semantics). Output is always double to match the rank_genes -// basic-stats path; the input data dtype is templated. +// `out` must be pre-zeroed; the atomicAdd also sums duplicate column indices +// (like scipy's sum_duplicates) -- bit-identical to a dense materialization for +// canonical CSR. Output is always double; input dtype is templated. template __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, @@ -18,7 +17,8 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const TData* __restrict__ data, double* __restrict__ out, int col_lb, int col_ub, int n_cells) { - const int row = blockIdx.x * blockDim.x + threadIdx.x; + const long long row = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (row >= n_cells) { return; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 4542d634..538006fa 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -33,16 +33,13 @@ static void launch_ovr_rank_dense_streaming( size_t cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); - std::vector streams(n_streams); - for (int i = 0; i < n_streams; ++i) { - cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); - } + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); - cudaEvent_t inputs_ready; - cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); - cudaEventRecord(inputs_ready, upstream_stream); + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); for (int i = 0; i < n_streams; ++i) { - cudaStreamWaitEvent(streams[i], inputs_ready, 0); + cuda_check(cudaStreamWaitEvent(streams[i], inputs_ready.get(), 0), + "wait on inputs_ready (dense OVR)"); } RmmScratchPool pool; @@ -127,8 +124,6 @@ static void launch_ovr_rank_dense_streaming( cudaGetErrorString(err)); } } - cudaEventDestroy(inputs_ready); - for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); } static void launch_ovo_rank_dense_tiered_unsorted_ref( @@ -179,16 +174,13 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( size_t ref_cub_temp_bytes = cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); - std::vector streams(n_streams); - for (int i = 0; i < n_streams; ++i) { - cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); - } + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); - cudaEvent_t inputs_ready; - cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); - cudaEventRecord(inputs_ready, upstream_stream); + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); for (int i = 0; i < n_streams; ++i) { - cudaStreamWaitEvent(streams[i], inputs_ready, 0); + cuda_check(cudaStreamWaitEvent(streams[i], inputs_ready.get(), 0), + "wait on inputs_ready (dense OVO)"); } RmmScratchPool pool; @@ -300,8 +292,6 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( cudaGetErrorString(err)); } } - cudaEventDestroy(inputs_ready); - for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); } template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index b0ea1e0e..dbda0256 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -178,6 +178,98 @@ struct HostRegisterGuard { } }; +// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind), +// fixing the leak when a throwing call skips a trailing manual destroy. The +// stream dtor SYNCHRONIZES before destroying. Teardown is safe in either +// pool-vs-streams declaration order: on the normal path the launcher syncs all +// streams before any dtor runs, and on exception unwind nothing re-allocates +// the freed (pool-retained) scratch before the streams dtor's sync drains the +// in-flight kernels. Invariant: do not place an allocating RAII member between +// an RmmScratchPool and these guards. +struct ScopedCudaStream { + cudaStream_t stream = nullptr; + + ScopedCudaStream() = default; + explicit ScopedCudaStream(unsigned int flags) { + cuda_check(cudaStreamCreateWithFlags(&stream, flags), + "cudaStreamCreateWithFlags"); + } + ~ScopedCudaStream() { + if (stream) { + cudaStreamSynchronize(stream); // drain before teardown + cudaStreamDestroy(stream); + } + } + operator cudaStream_t() const { + return stream; + } + cudaStream_t get() const { + return stream; + } + ScopedCudaStream(const ScopedCudaStream&) = delete; + ScopedCudaStream& operator=(const ScopedCudaStream&) = delete; +}; + +struct ScopedCudaStreams { + std::vector streams; + + // `flags` is explicit so call sites keep their original stream semantics. + ScopedCudaStreams(int n, unsigned int flags) { + streams.reserve(n > 0 ? (size_t)n : 0); + for (int i = 0; i < n; ++i) { + cudaStream_t s = nullptr; + cudaError_t err = cudaStreamCreateWithFlags(&s, flags); + if (err != cudaSuccess) { + // dtor won't run on ctor throw; reclaim what we made. + for (cudaStream_t prev : streams) { + cudaStreamSynchronize(prev); + cudaStreamDestroy(prev); + } + throw std::runtime_error( + std::string("cudaStreamCreateWithFlags failed: ") + + cudaGetErrorString(err)); + } + streams.push_back(s); + } + } + ~ScopedCudaStreams() { + for (cudaStream_t s : streams) { + if (!s) continue; + cudaStreamSynchronize(s); // drain before teardown + cudaStreamDestroy(s); + } + } + cudaStream_t operator[](int i) const { + return streams[i]; + } + int size() const { + return (int)streams.size(); + } + ScopedCudaStreams(const ScopedCudaStreams&) = delete; + ScopedCudaStreams& operator=(const ScopedCudaStreams&) = delete; +}; + +struct ScopedCudaEvent { + cudaEvent_t event = nullptr; + + ScopedCudaEvent() = default; + explicit ScopedCudaEvent(unsigned int flags) { + cuda_check(cudaEventCreateWithFlags(&event, flags), + "cudaEventCreateWithFlags"); + } + ~ScopedCudaEvent() { + if (event) cudaEventDestroy(event); + } + void record(cudaStream_t stream) { + cuda_check(cudaEventRecord(event, stream), "cudaEventRecord"); + } + cudaEvent_t get() const { + return event; + } + ScopedCudaEvent(const ScopedCudaEvent&) = delete; + ScopedCudaEvent& operator=(const ScopedCudaEvent&) = delete; +}; + static inline int round_up_to_warp(int n) { int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 2e3696fd..8905eb51 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -72,10 +72,8 @@ static void ovo_streaming_csr_impl( cub_temp_bytes = cub_grp_bytes; } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - cudaStream_t ref_stream; - cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + ScopedCudaStream ref_stream(cudaStreamNonBlocking); int* d_sort_group_ids = nullptr; if (run_huge) { @@ -220,8 +218,6 @@ static void ovo_streaming_csr_impl( cudaGetErrorString(err)); } } - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); - cudaStreamDestroy(ref_stream); } /** @@ -277,8 +273,7 @@ static void ovo_streaming_csc_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); RmmScratchPool pool; int* d_sort_group_ids = nullptr; @@ -402,5 +397,4 @@ static void ovo_streaming_csc_impl( std::string("CUDA error in OVO device CSC streaming: ") + cudaGetErrorString(err)); } - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 8ffe1208..eb9c06b5 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -63,8 +63,7 @@ static void ovo_streaming_csc_host_impl( if (nnz > max_nnz) max_nnz = nnz; } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); RmmScratchPool pool; @@ -287,8 +286,6 @@ static void ovo_streaming_csc_host_impl( std::string("CUDA error in wilcoxon streaming: ") + cudaGetErrorString(err)); } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** @@ -511,8 +508,7 @@ static void ovo_streaming_csr_host_impl( int ref_items_i32 = checked_cub_items(ref_items, "OVO host CSR dense reference cache"); float* d_ref_sorted = pool.alloc(ref_items); - cudaStream_t ref_stream; - cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); + ScopedCudaStream ref_stream(cudaStreamNonBlocking); { ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); @@ -568,7 +564,6 @@ static void ovo_streaming_csr_host_impl( cuda_check(cudaStreamSynchronize(ref_stream), "host CSR OVO ref sort sync"); } // ref scratch drops here - cudaStreamDestroy(ref_stream); // ---- Phase 2: Per-pack streaming ---- auto t1 = make_ovo_tier_plan(h_grp_offsets, n_test); @@ -592,8 +587,7 @@ static void ovo_streaming_csr_host_impl( cub_segmented_sortkeys_temp_bytes(max_sub_items_i32, max_segments); } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); struct StreamBuf { float* d_grp_data_f32; @@ -771,5 +765,4 @@ static void ovo_streaming_csr_host_impl( std::string("CUDA error in ovo csr host streaming: ") + cudaGetErrorString(err)); } - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 67bc1dfe..a5bb3f58 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -12,7 +12,7 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, IndptrT re = indptr[row + 1]; for (IndptrT p = rs; p < re; ++p) { int c = (int)indices[p]; - if (c < n_cols) atomicAdd(&col_counts[c], 1u); + if (c >= 0 && c < n_cols) atomicAdd(&col_counts[c], 1u); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 9690e5ce..eb9b7687 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -38,8 +38,7 @@ static void ovr_sparse_csc_host_streaming_impl( max_nnz_i32, sub_batch_cols); } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); RmmScratchPool pool; int* d_group_codes = pool.alloc(n_rows); @@ -227,8 +226,6 @@ static void ovr_sparse_csc_host_streaming_impl( std::string("CUDA error in sparse host CSC streaming: ") + cudaGetErrorString(err)); } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } // ============================================================================ @@ -329,8 +326,7 @@ static void ovr_sparse_csr_host_streaming_impl( while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) n_streams--; - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); // Pin the source CSR arrays as mapped memory. The scatter kernel reads // only the requested column window from each row. @@ -499,8 +495,6 @@ static void ovr_sparse_csr_host_streaming_impl( std::string("CUDA error in sparse host CSR streaming: ") + cudaGetErrorString(err)); } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } // ============================================================================ @@ -541,8 +535,7 @@ static void ovr_sparse_csc_streaming_impl( cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); } - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; @@ -638,8 +631,6 @@ static void ovr_sparse_csc_streaming_impl( std::string("CUDA error in sparse ovr streaming: ") + cudaGetErrorString(err)); } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } // ============================================================================ @@ -740,8 +731,7 @@ static void ovr_sparse_csr_streaming_impl( while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) n_streams--; - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); int tpb = UTIL_BLOCK_SIZE; int scatter_blocks = (n_rows + tpb - 1) / tpb; @@ -842,6 +832,4 @@ static void ovr_sparse_csr_streaming_impl( std::string("CUDA error in sparse CSR ovr streaming: ") + cudaGetErrorString(err)); } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 6f8d90df..fa5d754b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -114,15 +114,16 @@ __global__ void rank_sums_sparse_ovr_kernel( for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { int grp = group_codes[si[i]]; if (grp < n_groups) { - atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); + atomicAdd(&grp_nz_count[(size_t)grp * acc_stride], 1.0); } } __syncthreads(); // --- Zero-rank contribution per group --- for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; - grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; + double n_zero_in_g = + group_sizes[g] - grp_nz_count[(size_t)g * acc_stride]; + grp_sums[(size_t)g * acc_stride] = n_zero_in_g * zero_avg_rank; } __syncthreads(); @@ -179,7 +180,7 @@ __global__ void rank_sums_sparse_ovr_kernel( for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; if (grp < n_groups) { - atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + atomicAdd(&grp_sums[(size_t)grp * acc_stride], avg_rank); } } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index a204d73e..0ca84bfa 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -76,14 +76,15 @@ def rank_genes_groups( """ Rank genes for characterizing groups using GPU acceleration. - Expects nonnegative expression data. Log1p/log-normalized data is expected - for biologically meaningful log fold changes; negative values are rejected - for eager in-memory inputs. + Log1p/log-normalized data is expected for biologically meaningful log fold + changes. Complex values are rejected. Sparse inputs with explicit negative + values fall back to the dense full-sort ranking path; dense inputs are + ranked directly and support any sign. .. note:: - **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and - `'wilcoxon_binned'` support Dask arrays. The `'wilcoxon'` and - `'logreg'` methods do not support Dask arrays. + **Dask support:** `'t-test'`, `'t-test_overestim_var'`, + `'wilcoxon_binned'`, and `'logreg'` support Dask arrays. The + `'wilcoxon'` method does not support Dask arrays. Parameters ---------- @@ -140,7 +141,7 @@ def rank_genes_groups( Key from `adata.layers` whose value will be used to perform tests on. chunk_size Number of genes to process at once for `'wilcoxon'` and - `'wilcoxon_binned'`. Default is 128 for `'wilcoxon'`. For + `'wilcoxon_binned'`. Default is 512 for `'wilcoxon'`. For `'wilcoxon_binned'` the default is sized dynamically based on ``n_groups`` and ``n_bins`` to keep histogram memory stable. pre_load diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index b093c168..72c1a32a 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -130,7 +130,6 @@ def __init__( tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None ) = None self._compute_stats_in_chunks: bool = False - self._ref_chunk_computed: set[int] = set() self._score_dtype = np.dtype(np.float32) def _init_stats_arrays(self, n_genes: int) -> None: @@ -314,49 +313,6 @@ def _accumulate_chunk_stats_vs_rest( rest_nnz / rest_sizes[:, None] ) - def _accumulate_chunk_stats_with_ref( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_index: int, - group_mask_gpu: cp.ndarray, - n_group: int, - n_ref: int, - ) -> None: - """Compute and store stats for one gene chunk (with reference mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - # Group stats - group_data = block[group_mask_gpu] - group_mean = group_data.mean(axis=0) - self.means[group_index, start:stop] = cp.asnumpy(group_mean) - - if n_group > 1: - group_var = group_data.var(axis=0, ddof=1) - self.vars[group_index, start:stop] = cp.asnumpy(group_var) - - if self.comp_pts: - group_nnz = (group_data != 0).sum(axis=0) - self.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) - - # Reference stats (only compute once, on first non-reference group) - if start not in self._ref_chunk_computed: - self._ref_chunk_computed.add(start) - ref_data = block[~group_mask_gpu] - ref_mean = ref_data.mean(axis=0) - self.means[self.ireference, start:stop] = cp.asnumpy(ref_mean) - - if n_ref > 1: - ref_var = ref_data.var(axis=0, ddof=1) - self.vars[self.ireference, start:stop] = cp.asnumpy(ref_var) - - if self.comp_pts: - ref_nnz = (ref_data != 0).sum(axis=0) - self.pts[self.ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) - def t_test( self, method: Literal["t-test", "t-test_overestim_var"] ) -> list[tuple[int, NDArray, NDArray]]: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 360928e7..2a1c8f4a 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -14,8 +14,6 @@ from numpy.typing import NDArray EPS = 1e-9 -WARP_SIZE = 32 -MAX_THREADS_PER_BLOCK = 512 MIN_GROUP_SIZE_WARNING = 25 @@ -138,11 +136,6 @@ def _select_groups( return groups_order, group_codes, group_sizes -def _round_up_to_warp(n: int) -> int: - """Round up to nearest multiple of WARP_SIZE, capped at MAX_THREADS_PER_BLOCK.""" - return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) - - def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 45088c81..d6fc36ac 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -1033,7 +1033,7 @@ def _wilcoxon_with_reference( sparse_X = sparse_X.copy() sparse_X.sort_indices() data, indices, indptr = _device_sparse_arrays_f32(sparse_X) - offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) + # offsets_gpu (built once above as int32) is reused here. # zeros, not empty: an all-empty test batch (n_all_grp == 0) # short-circuits the kernel without writing rank_sums. rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index d5f4ed0d..f956f9db 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -183,6 +183,19 @@ def wilcoxon_binned( if bin_range is None: bin_range = "log1p" if isinstance(X, DaskArray) else "auto" + # The fixed log1p [0, 15] range assumes nonnegative data. For signed sparse + # input the dense fallback would clamp negatives into the lowest bin and + # silently produce wrong rank sums, so switch to the data-driven 'auto' + # range (which spans the true [min, max], including negatives). + if rg._sparse_negative_fallback and bin_range == "log1p": + warnings.warn( + "bin_range='log1p' is invalid for sparse input with negative values " + "(the fixed [0, 15] range would clamp them); using bin_range='auto'.", + RuntimeWarning, + stacklevel=4, + ) + bin_range = "auto" + # Prepare GPU arrays and bin arithmetic if bin_range == "auto": bin_low, bin_high = _data_range(X) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 5abcda84..fb18ae36 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -101,84 +101,46 @@ def test_rank_genes_groups_complex_values_raise(fmt): rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) -def test_device_sparse_int64_indptr_selects_i64_kernel(): - from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _device_sparse_arrays_f32, - _device_sparse_fn, - ) - - class FakeSparse: - data = cp.asarray([1.0], dtype=cp.float32) - indices = cp.asarray([0], dtype=cp.int32) - indptr = cp.asarray([0, np.iinfo(np.int32).max + 1], dtype=cp.int64) - - # int64 indptr is preserved (no truncation, no dense fallback); the row - # indices stay int32 because cells/genes are always < 2^31. - data, indices, indptr = _device_sparse_arrays_f32(FakeSparse()) - assert indptr.dtype == cp.int64 - assert indices.dtype == cp.int32 - assert data.dtype == cp.float32 - - # Dispatch routes an int64 indptr to the *_i64 kernel binding, int32 to base. - assert ( - _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr) - is _wcs.ovr_sparse_csc_device_i64 - ) - assert ( - _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indices) - is _wcs.ovr_sparse_csc_device - ) - - -@pytest.mark.parametrize("layout", ["csc", "csr"]) -def test_device_ovr_sparse_i64_indptr_matches_i32(layout): - # cupyx coerces small matrices to int32 indptr, so int64 support is only - # reachable for nnz > 2^31. Exercise the int64-templated kernels directly - # with a hand-built int64 indptr and assert bit-parity with the int32 path. - from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs - +@pytest.mark.parametrize("layout", ["csr", "csc"]) +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_device_sparse_int64_indptr_matches_scanpy(layout, reference): + # Real int64 indptr only occurs at nnz > 2^31 (unallocatable in CI). cupy + # >= 14.1 preserves an explicitly promoted int64 indptr, so a small matrix + # promoted to int64 drives the *_i64 device kernels through the public API. rng = np.random.default_rng(0) - n_rows, n_cols, n_groups = 120, 10, 4 - dense = np.abs(rng.standard_normal((n_rows, n_cols))).astype(np.float32) - dense[dense < 0.6] = 0.0 - mat = sp.csc_matrix(dense) if layout == "csc" else sp.csr_matrix(dense) - mat.sort_indices() - gcodes = rng.integers(0, n_groups, n_rows).astype(np.int32) - gsizes = np.bincount(gcodes, minlength=n_groups).astype(np.float64) - - data = cp.asarray(mat.data, dtype=cp.float32) - indices = cp.asarray(mat.indices, dtype=cp.int32) - g = cp.asarray(gcodes) - gs = cp.asarray(gsizes) - base = getattr(_wcs, f"ovr_sparse_{layout}_device") - i64 = getattr(_wcs, f"ovr_sparse_{layout}_device_i64") - - def run(indptr_dtype, fn): - indptr = cp.asarray(mat.indptr, dtype=indptr_dtype) - rs = cp.empty((n_groups, n_cols), dtype=cp.float64) - tc = cp.ones(n_cols, dtype=cp.float64) - fn( - data, - indices, - indptr, - g, - gs, - rs, - tc, - n_rows=n_rows, - n_cols=n_cols, - n_groups=n_groups, - compute_tie_corr=True, - sub_batch_cols=64, - ) - cp.cuda.get_current_stream().synchronize() - return rs.get(), tc.get() - - rs32, tc32 = run(cp.int32, base) - rs64, tc64 = run(cp.int64, i64) - np.testing.assert_array_equal(rs32, rs64) - np.testing.assert_array_equal(tc32, tc64) + dense = np.abs(rng.standard_normal((150, 8))).astype(np.float32) + dense[dense < 0.5] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{j}" for j in range(8)]) + + ctor = cpsp.csr_matrix if layout == "csr" else cpsp.csc_matrix + mat = ctor(cp.asarray(dense)) + mat.indptr = mat.indptr.astype(cp.int64) + mat.indices = mat.indices.astype(cp.int64) + assert mat.indptr.dtype == cp.int64 + + adata = sc.AnnData(X=mat, obs=obs.copy(), var=var.copy()) + adata_cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(adata_cpu, "group", **kw) + g = adata.uns["rank_genes_groups"] + c = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals", "pvals_adj"): + for grp in g[field].dtype.names: + np.testing.assert_allclose( + np.asarray(g[field][grp], dtype=float), + np.asarray(c[field][grp], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): @@ -1341,3 +1303,198 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): np.testing.assert_array_equal( dense_df["pvals"].values, sparse_df["pvals"].values ) + + +def _make_count_adata(seed=0, n_obs=120, n_genes=6, n_groups=3): + # Integer-valued counts as float64: float32-exact, zeros create ties. + rng = np.random.default_rng(seed) + X = rng.integers(0, 8, size=(n_obs, n_genes)).astype(np.float64) + X[X < 2] = 0.0 # extra zeros -> implicit-zero tie blocks + labels = np.array([f"{i % n_groups}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + adata.uns["log1p"] = {"base": None} + return adata + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_wilcoxon_host_sparse_float64_data_matches_scanpy(fmt, reference): + # float64 host-sparse data exercises the *_f64 kernel bindings. + adata = _make_count_adata(seed=3) + adata_cpu = adata.copy() + mat = sp.csr_matrix(adata.X) if fmt == "scipy_csr" else sp.csc_matrix(adata.X) + assert mat.dtype == np.float64 + adata.X = mat + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": adata.n_vars, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + g = adata.uns["rank_genes_groups"] + c = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals", "pvals_adj"): + for grp in g[field].dtype.names: + np.testing.assert_allclose( + np.asarray(g[field][grp], dtype=float), + np.asarray(c[field][grp], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +@pytest.mark.parametrize("data_dtype", [np.int32, np.int64, np.uint16, bool]) +def test_wilcoxon_sparse_integer_bool_data_matches_float32(fmt, data_dtype): + # Integer/bool data hits the cast-to-float32 branch; must match float32. + rng = np.random.default_rng(5) + n_obs, n_genes = 100, 6 + counts = rng.integers(0, 5, size=(n_obs, n_genes)) + if data_dtype is bool: + counts = counts > 2 + typed = counts.astype(data_dtype) + f32 = counts.astype(np.float32) + labels = np.array([f"{i % 3}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + + def run(arr): + adata = sc.AnnData(X=_to_format(arr, fmt), obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + n_genes=n_genes, + ) + return adata.uns["rank_genes_groups"] + + r_typed = run(typed) + r_f32 = run(f32) + for grp in r_typed["scores"].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_typed["scores"][grp], dtype=float), + np.asarray(r_f32["scores"][grp], dtype=float), + ) + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_sparse_float16_data_raises(fmt): + # Unsupported float16 sparse data is rejected with a TypeError. + rng = np.random.default_rng(0) + dense = np.abs(rng.standard_normal((40, 4))).astype(np.float32) + mat = sp.csr_matrix(dense) if fmt == "scipy_csr" else sp.csc_matrix(dense) + mat.data = mat.data.astype(np.float16) + assert mat.data.dtype == np.float16 + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(40)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="float32"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +@pytest.mark.parametrize("reference", ["rest", "b"]) +@pytest.mark.parametrize("fmt", ["scipy_csc", "cupy_csc", "cupy_dense"]) +def test_rank_genes_groups_wilcoxon_return_u_values_more_formats(reference, fmt): + # U-value + continuity epilogue on CSC (host/device) and device-dense. + X = np.array( + [ + [5.0, 0.0, 1.0, 2.0], + [4.0, 0.0, 1.0, 2.0], + [1.0, 3.0, 2.0, 2.0], + [0.0, 2.0, 2.0, 2.0], + [2.0, 1.0, 0.0, 3.0], + [3.0, 1.0, 0.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "b", "b", "c", "c"]) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference=reference, + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + result = adata.uns["rank_genes_groups"] + assert result["params"]["return_u_values"] is True + assert result["scores"].dtype["a"] == np.dtype("float64") + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + mask_group = labels == "a" + mask_ref = labels != "a" if reference == "rest" else labels == reference + expected = np.array( + [ + mannwhitneyu( + X[mask_group, gene], + X[mask_ref, gene], + alternative="two-sided", + ).statistic + for gene in range(X.shape[1]) + ], + dtype=np.float64, + ) + gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} + expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) + np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_rank_genes_groups_sparse_negative_values_fallback_ovo(fmt): + # Sparse-negative dense fallback on the OVO (with-reference) path. + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + [-2.0, 1.0, 0.0], + [1.0, 0.0, 3.0], + ], + dtype=np.float64, + ) + obs = pd.DataFrame({"group": pd.Categorical(list("aaabbb"), categories=["a", "b"])}) + var = pd.DataFrame(index=["g0", "g1", "g2"]) + + sparse_adata = sc.AnnData(X=_to_format(X, fmt), obs=obs.copy(), var=var.copy()) + dense_fmt = "cupy_dense" if fmt.startswith("cupy") else "numpy_dense" + dense_adata = sc.AnnData(X=_to_format(X, dense_fmt), obs=obs.copy(), var=var.copy()) + + kw = {"groupby": "group", "method": "wilcoxon", "use_raw": False, "reference": "b"} + rsc.tl.rank_genes_groups(sparse_adata, **kw) + rsc.tl.rank_genes_groups(dense_adata, **kw) + + sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] + dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] + for group in sp_scores.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_scores[group], dtype=float), + np.asarray(dn_scores[group], dtype=float), + rtol=1e-13, + atol=1e-13, + ) From a65155aa67e25929f9365f9ea57adcd675e37a92 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 16 Jun 2026 18:33:59 +0200 Subject: [PATCH 19/36] more cleanup --- .../_cuda/wilcoxon/wilcoxon.cu | 7 +- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 18 ++-- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 4 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 45 ++++------ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 60 ++++--------- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 66 ++++++-------- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 50 ++++------- .../tools/_rank_genes_groups/_wilcoxon.py | 85 +++---------------- 8 files changed, 98 insertions(+), 237 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 538006fa..03a0f0da 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -33,6 +33,8 @@ static void launch_ovr_rank_dense_streaming( size_t cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); ScopedCudaEvent inputs_ready(cudaEventDisableTiming); @@ -42,7 +44,6 @@ static void launch_ovr_rank_dense_streaming( "wait on inputs_ready (dense OVR)"); } - RmmScratchPool pool; struct StreamBuf { float* keys_out; int* vals_in; @@ -174,6 +175,8 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( size_t ref_cub_temp_bytes = cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); ScopedCudaEvent inputs_ready(cudaEventDisableTiming); @@ -182,8 +185,6 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( cuda_check(cudaStreamWaitEvent(streams[i], inputs_ready.get(), 0), "wait on inputs_ready (dense OVO)"); } - - RmmScratchPool pool; int* d_sort_group_ids = nullptr; if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index dbda0256..4d8402df 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -180,12 +180,10 @@ struct HostRegisterGuard { // RAII for CUDA streams/events: reclaim on every path (incl. exception unwind), // fixing the leak when a throwing call skips a trailing manual destroy. The -// stream dtor SYNCHRONIZES before destroying. Teardown is safe in either -// pool-vs-streams declaration order: on the normal path the launcher syncs all -// streams before any dtor runs, and on exception unwind nothing re-allocates -// the freed (pool-retained) scratch before the streams dtor's sync drains the -// in-flight kernels. Invariant: do not place an allocating RAII member between -// an RmmScratchPool and these guards. +// stream dtor SYNCHRONIZES before destroying. Convention: declare the +// RmmScratchPool BEFORE these guards so the streams (destroyed first) drain +// their in-flight kernels before the pool (destroyed last) frees the scratch +// those kernels read -- safe on the normal and exception-unwind paths alike. struct ScopedCudaStream { cudaStream_t stream = nullptr; @@ -342,9 +340,8 @@ __global__ void csr_gather_cast_accumulate_mapped_kernel( const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, const int* __restrict__ d_stats_codes, int fixed_slot, float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, - double* __restrict__ group_sums, double* __restrict__ group_sq_sums, - double* __restrict__ group_nnz, int n_target_rows, int n_cols, - int n_groups_stats, bool compute_sums, bool compute_sq_sums, + double* __restrict__ group_sums, double* __restrict__ group_nnz, + int n_target_rows, int n_cols, int n_groups_stats, bool compute_sums, bool compute_nnz) { int r = blockIdx.x; if (r >= n_target_rows) return; @@ -365,9 +362,6 @@ __global__ void csr_gather_cast_accumulate_mapped_kernel( if (compute_sums) { atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); } - if (compute_sq_sums) { - atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); - } if (compute_nnz && v != 0.0) { atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 8905eb51..c325fc42 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -273,9 +273,9 @@ static void ovo_streaming_csc_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - ScopedCudaStreams streams(n_streams, cudaStreamDefault); - + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamDefault); int* d_sort_group_ids = nullptr; if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index eb9c06b5..2b82aed9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -12,10 +12,9 @@ static void ovo_streaming_csc_host_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_ref_row_map, const int* h_grp_row_map, const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, - int n_groups, int n_groups_stats, bool compute_tie_corr, - bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_ref, + int n_all_grp, int n_rows, int n_cols, int n_groups, int n_groups_stats, + bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- @@ -63,9 +62,9 @@ static void ovo_streaming_csc_host_impl( if (nnz > max_nnz) max_nnz = nnz; } - ScopedCudaStreams streams(n_streams, cudaStreamDefault); - + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamDefault); int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); @@ -123,7 +122,6 @@ static void ovo_streaming_csc_host_impl( double* d_rank_sums; double* d_tie_corr; double* d_group_sums; - double* d_group_sq_sums; double* d_group_nnz; }; std::vector bufs(n_streams); @@ -147,8 +145,6 @@ static void ovo_streaming_csc_host_impl( pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_group_sums = pool.alloc((size_t)n_groups_stats * sub_batch_cols); - bufs[s].d_group_sq_sums = pool.alloc( - compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); bufs[s].d_group_nnz = pool.alloc( compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); if (run_huge) { @@ -168,8 +164,8 @@ static void ovo_streaming_csc_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); bool cast_use_gmem = false; - size_t smem_cast = cast_accumulate_smem_config( - n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + size_t smem_cast = + cast_accumulate_smem_config(n_groups_stats, compute_nnz, cast_use_gmem); // Pin only the sparse input arrays; outputs live on the device. size_t total_nnz = (size_t)h_indptr[n_cols]; @@ -208,9 +204,9 @@ static void ovo_streaming_csc_host_impl( // ---- Cast to float32 for sort + accumulate stats in float64 ---- launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, - buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, - buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, - compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_nnz, + sb_cols, n_groups_stats, compute_nnz, UTIL_BLOCK_SIZE, smem_cast, + cast_use_gmem, stream); // ---- Extract ref from CSC via row_map, sort ---- cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), @@ -262,12 +258,6 @@ static void ovo_streaming_csc_host_impl( buf.d_group_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups_stats, cudaMemcpyDeviceToDevice, stream); - if (compute_sq_sums) { - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - } if (compute_nnz) { cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), buf.d_group_nnz, sb_cols * sizeof(double), @@ -316,8 +306,7 @@ static void ovo_streaming_csr_host_impl( int n_full_rows, const int* h_ref_row_ids, int n_ref, const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, - double* d_group_sq_sums, double* d_group_nnz, int n_cols, - int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + double* d_group_nnz, int n_cols, int n_groups_stats, bool compute_tie_corr, bool compute_nnz, bool compute_sums, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; @@ -431,10 +420,6 @@ static void ovo_streaming_csr_host_impl( cudaMemsetAsync(d_group_sums, 0, (size_t)n_groups_stats * n_cols * sizeof(double)); } - if (compute_sq_sums) { - cudaMemsetAsync(d_group_sq_sums, 0, - (size_t)n_groups_stats * n_cols * sizeof(double)); - } if (compute_nnz) { cudaMemsetAsync(d_group_nnz, 0, (size_t)n_groups_stats * n_cols * sizeof(double)); @@ -534,8 +519,8 @@ static void ovo_streaming_csr_host_impl( d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, d_ref_indptr, /*d_stats_codes=*/nullptr, /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, - d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, - n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + d_group_sums, d_group_nnz, n_ref, n_cols, n_groups_stats, + compute_sums, compute_nnz); CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); } @@ -707,8 +692,8 @@ static void ovo_streaming_csr_host_impl( d_grp_row_ids + row_start, buf.d_grp_indptr, buf.d_pack_stats_codes, /*fixed_slot=*/-1, buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, - d_group_sq_sums, d_group_nnz, pack_rows, n_cols, - n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + d_group_nnz, pack_rows, n_cols, n_groups_stats, + compute_sums, compute_nnz); CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index eb9b7687..1c71bd0f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -11,9 +11,8 @@ template static void ovr_sparse_csc_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -38,9 +37,9 @@ static void ovr_sparse_csc_host_streaming_impl( max_nnz_i32, sub_batch_cols); } - ScopedCudaStreams streams(n_streams, cudaStreamDefault); - + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamDefault); int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); struct StreamBuf { @@ -54,7 +53,6 @@ static void ovr_sparse_csc_host_streaming_impl( double* d_rank_sums; double* d_tie_corr; double* d_group_sums; - double* d_group_sq_sums; double* d_group_nnz; double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem }; @@ -72,10 +70,6 @@ static void ovr_sparse_csc_host_streaming_impl( bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); bufs[s].d_group_sums = pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_sq_sums = - compute_sq_sums - ? pool.alloc((size_t)n_groups * sub_batch_cols) - : nullptr; bufs[s].d_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; @@ -111,8 +105,8 @@ static void ovr_sparse_csc_host_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; - size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, - compute_nnz, cast_use_gmem); + size_t smem_cast = + cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); // In gmem mode the sparse rank kernel accumulates into rank_sums directly // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). @@ -165,9 +159,8 @@ static void ovr_sparse_csc_host_streaming_impl( // Cast to float32 for sort + accumulate stats in float64 launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, - buf.d_seg_offsets, d_group_codes, buf.d_group_sums, - buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, - compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_nnz, + sb_cols, n_groups, compute_nnz, tpb, smem_cast, cast_use_gmem, stream); // CUB sort only stored nonzeros (float32 keys) @@ -202,12 +195,6 @@ static void ovr_sparse_csc_host_streaming_impl( buf.d_group_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); - if (compute_sq_sums) { - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } if (compute_nnz) { cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), buf.d_group_nnz, sb_cols * sizeof(double), @@ -244,9 +231,8 @@ template static void ovr_sparse_csr_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -298,8 +284,8 @@ static void ovr_sparse_csr_host_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; - size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, - compute_nnz, cast_use_gmem); + size_t smem_cast = + cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); int n_streams = N_STREAMS; if (n_batches < n_streams) n_streams = n_batches; @@ -309,9 +295,6 @@ static void ovr_sparse_csr_host_streaming_impl( (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + sub_batch_cols * sizeof(double); - if (compute_sq_sums) { - per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); - } if (compute_nnz) { per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } @@ -377,7 +360,6 @@ static void ovr_sparse_csr_host_streaming_impl( double* sub_rank_sums; double* sub_tie_corr; double* sub_group_sums; - double* sub_group_sq_sums; double* sub_group_nnz; double* d_nz_scratch; }; @@ -396,10 +378,6 @@ static void ovr_sparse_csr_host_streaming_impl( bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); bufs[s].sub_group_sums = pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_group_sq_sums = - compute_sq_sums - ? pool.alloc((size_t)n_groups * sub_batch_cols) - : nullptr; bufs[s].sub_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; @@ -439,9 +417,8 @@ static void ovr_sparse_csr_host_streaming_impl( launch_ovr_cast_and_accumulate_sparse( buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, buf.col_offsets, d_group_codes, buf.sub_group_sums, - buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, - compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, - stream); + buf.sub_group_nnz, sb_cols, n_groups, compute_nnz, tpb, smem_cast, + cast_use_gmem, stream); if (batch_nnz > 0) { size_t temp = cub_temp_bytes; @@ -472,12 +449,6 @@ static void ovr_sparse_csr_host_streaming_impl( buf.sub_group_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); - if (compute_sq_sums) { - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.sub_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } if (compute_nnz) { cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), buf.sub_group_nnz, sb_cols * sizeof(double), @@ -535,13 +506,14 @@ static void ovr_sparse_csc_streaming_impl( cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); } + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamDefault); int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - RmmScratchPool pool; struct StreamBuf { float* keys_out; int* vals_out; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 4ac0b62b..9feeae0b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -61,24 +61,21 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ gpu_array_c d_group_nnz, int n_rows, int n_cols, \ - int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ - bool compute_nnz, int sub_batch_cols) { \ + int n_groups, bool compute_tie_corr, bool compute_nnz, \ + int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csc_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ - n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ - sub_batch_cols); \ + d_group_nnz.data(), n_rows, n_cols, n_groups, \ + compute_tie_corr, compute_nnz, sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ - "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "compute_nnz"_a = true, \ "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); @@ -100,24 +97,21 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ gpu_array_c d_group_nnz, int n_rows, int n_cols, \ - int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ - bool compute_nnz, int sub_batch_cols) { \ + int n_groups, bool compute_tie_corr, bool compute_nnz, \ + int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csr_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ - n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ - sub_batch_cols); \ + d_group_nnz.data(), n_rows, n_cols, n_groups, \ + compute_tie_corr, compute_nnz, sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ - "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "compute_nnz"_a = true, \ "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); @@ -175,27 +169,24 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ int n_rows, int n_cols, int n_groups, int n_groups_stats, \ - bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ - int sub_batch_cols) { \ + bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovo_streaming_csc_host_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_ref_row_map.data(), h_grp_row_map.data(), \ h_grp_offsets.data(), h_stats_codes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ - n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ - compute_sq_sums, compute_nnz, sub_batch_cols); \ + d_group_nnz.data(), n_ref, n_all_grp, n_rows, n_cols, \ + n_groups, n_groups_stats, compute_tie_corr, compute_nnz, \ + sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ - "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ - "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, "d_group_nnz"_a, \ + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, \ + "n_groups"_a, "n_groups_stats"_a, "compute_tie_corr"_a, \ "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); @@ -216,27 +207,24 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ gpu_array_c d_group_nnz, int n_full_rows, \ int n_ref, int n_all_grp, int n_cols, int n_test, \ - int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ - bool compute_nnz, bool compute_sums, int sub_batch_cols) { \ + int n_groups_stats, bool compute_tie_corr, bool compute_nnz, \ + bool compute_sums, int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovo_streaming_csr_host_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ - d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ - n_groups_stats, compute_tie_corr, compute_sq_sums, \ - compute_nnz, compute_sums, sub_batch_cols); \ + d_tie_corr.data(), d_group_sums.data(), d_group_nnz.data(), \ + n_cols, n_groups_stats, compute_tie_corr, compute_nnz, \ + compute_sums, sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ - "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ - "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ - "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ - "compute_nnz"_a = true, "compute_sums"_a = true, \ + "d_group_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_full_rows"_a, \ + "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, "n_groups_stats"_a, \ + "compute_tie_corr"_a, "compute_nnz"_a = true, "compute_sums"_a = true, \ "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index fa5d754b..a28d1c6c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -273,9 +273,9 @@ static inline void launch_ovr_sparse_rank( // global memory. Same large-n_groups workloads that drive // sparse_ovr_smem_config to gmem also drive this one; both fallbacks are // load-bearing, not dead. -static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, - bool compute_nnz, bool& use_gmem) { - int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); +static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, + bool& use_gmem) { + int n_arrays = 1 + (compute_nnz ? 1 : 0); size_t need = (size_t)n_arrays * n_groups * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; @@ -302,8 +302,7 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int sb_cols, int n_groups, bool compute_sq_sums = true, + double* __restrict__ group_nnz, int sb_cols, int n_groups, bool compute_nnz = true) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -312,19 +311,13 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( int seg_end = col_seg_offsets[col + 1]; // Packed layout matching cast_accumulate_smem_config, which sizes the - // dynamic smem as (1 + compute_sq_sums + compute_nnz) * n_groups doubles. - // s_nnz must follow only the arrays that are actually present: using a - // fixed 2*n_groups offset over-runs the allocation when sq-sums is off but - // nnz is on (the host OVR pts path), corrupting/faulting at larger - // n_groups. + // dynamic smem as (1 + compute_nnz) * n_groups doubles. extern __shared__ double smem[]; double* s_sum = smem; - double* s_sq = smem + n_groups; - double* s_nnz = smem + (compute_sq_sums ? 2 : 1) * n_groups; + double* s_nnz = smem + n_groups; for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { s_sum[g] = 0.0; - if (compute_sq_sums) s_sq[g] = 0.0; if (compute_nnz) s_nnz[g] = 0.0; } __syncthreads(); @@ -337,7 +330,6 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( int g = group_codes[row]; if (g < n_groups) { atomicAdd(&s_sum[g], v); - if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); } } @@ -345,9 +337,6 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { group_sums[(size_t)g * sb_cols + col] = s_sum[g]; - if (compute_sq_sums) { - group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; - } if (compute_nnz) { group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; } @@ -364,8 +353,7 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int sb_cols, int n_groups, bool compute_sq_sums = true, + double* __restrict__ group_nnz, int sb_cols, int n_groups, bool compute_nnz = true) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -381,9 +369,6 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( int g = group_codes[row]; if (g < n_groups) { atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); - if (compute_sq_sums) { - atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); - } if (compute_nnz && v != 0.0) { atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); } @@ -395,32 +380,27 @@ template static void launch_ovr_cast_and_accumulate_sparse( const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, - double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, - bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, - bool use_gmem, cudaStream_t stream) { + double* d_group_nnz, int sb_cols, int n_groups, bool compute_nnz, int tpb, + size_t smem_cast, bool use_gmem, cudaStream_t stream) { if (use_gmem) { size_t stats_items = (size_t)n_groups * sb_cols; cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); - if (compute_sq_sums) { - cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), - stream); - } if (compute_nnz) { cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), stream); } ovr_cast_and_accumulate_sparse_global_kernel - <<>>( - d_data_orig, d_data_f32, d_indices, d_col_offsets, - d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, - sb_cols, n_groups, compute_sq_sums, compute_nnz); + <<>>(d_data_orig, d_data_f32, d_indices, + d_col_offsets, d_group_codes, + d_group_sums, d_group_nnz, sb_cols, + n_groups, compute_nnz); CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); } else { ovr_cast_and_accumulate_sparse_kernel <<>>( d_data_orig, d_data_f32, d_indices, d_col_offsets, - d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, - sb_cols, n_groups, compute_sq_sums, compute_nnz); + d_group_codes, d_group_sums, d_group_nnz, sb_cols, n_groups, + compute_nnz); CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); } } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index d6fc36ac..159b7fe0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -151,24 +151,20 @@ def _fill_ovo_chunk_stats( def _fill_basic_stats_from_accumulators( rg: _RankGenes, group_sums: cp.ndarray, - group_sq_sums: cp.ndarray, group_nnz: cp.ndarray, group_sizes: np.ndarray, *, n_cells: int, - compute_vars: bool, total_sums: cp.ndarray | None = None, - total_sq_sums: cp.ndarray | None = None, total_nnz: cp.ndarray | None = None, ) -> None: + # Wilcoxon does not output per-group variance; vars are left zero (real + # group means/pts come from group_sums/group_nnz, which drive the lfc/pts + # output and the rank test). n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] means = group_sums / n rg.means = cp.asnumpy(means) - if compute_vars: - group_ss = group_sq_sums - n * means**2 - rg.vars = cp.asnumpy(cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)) - else: - rg.vars = np.zeros_like(rg.means) + rg.vars = np.zeros_like(rg.means) rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None n_rest = cp.float64(n_cells) - n @@ -177,13 +173,7 @@ def _fill_basic_stats_from_accumulators( rest_sums = total_sums - group_sums rest_means = rest_sums / n_rest rg.means_rest = cp.asnumpy(rest_means) - if compute_vars: - if total_sq_sums is None: - total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) - rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 - rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)) - else: - rg.vars_rest = np.zeros_like(rg.means_rest) + rg.vars_rest = np.zeros_like(rg.means_rest) if rg.comp_pts: if total_nnz is None: total_nnz = group_nnz.sum(axis=0, keepdims=True) @@ -196,13 +186,11 @@ def _fill_basic_stats_from_accumulators( def _fill_ovo_stats_from_accumulators( rg: _RankGenes, group_sums_slots: cp.ndarray, - group_sq_sums_slots: cp.ndarray, group_nnz_slots: cp.ndarray, *, group_sizes: NDArray, test_group_indices: list[int], n_ref: int, - compute_vars: bool, ) -> None: n_test = len(test_group_indices) n_genes = int(group_sums_slots.shape[1]) @@ -221,10 +209,7 @@ def _fill_ovo_stats_from_accumulators( means_slots = group_sums_slots / slot_sizes_dev rg.means[slot_group_indices] = cp.asnumpy(means_slots) - if compute_vars: - group_ss = group_sq_sums_slots - slot_sizes_dev * means_slots**2 - denom = cp.maximum(slot_sizes_dev - 1.0, 1.0) - rg.vars[slot_group_indices] = cp.asnumpy(cp.maximum(group_ss / denom, 0)) + # vars left zero: wilcoxon does not output per-group variance. if rg.comp_pts: rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) @@ -408,8 +393,8 @@ def _device_sparse_fn(module, base_name: str, indptr: cp.ndarray): def _column_totals_for_host_matrix( - X, *, compute_sq_sums: bool, compute_nnz: bool -) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + X, *, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None]: n_cols = X.shape[1] if isinstance(X, sp.spmatrix | sp.sparray): data = np.asarray(X.data) @@ -422,11 +407,6 @@ def _column_totals_for_host_matrix( sums = np.zeros(n_cols, dtype=np.float64) if starts.size: sums[nonempty] = np.add.reduceat(values, starts) - sq_sums = None - if compute_sq_sums: - sq_sums = np.zeros(n_cols, dtype=np.float64) - if starts.size: - sq_sums[nonempty] = np.add.reduceat(values * values, starts) nnz = None if compute_nnz: nnz = np.zeros(n_cols, dtype=np.float64) @@ -439,13 +419,6 @@ def _column_totals_for_host_matrix( sums = np.bincount(indices, weights=values, minlength=n_cols).astype( np.float64, copy=False ) - sq_sums = ( - np.bincount(indices, weights=values * values, minlength=n_cols).astype( - np.float64, copy=False - ) - if compute_sq_sums - else None - ) nnz = ( np.bincount( indices, @@ -464,17 +437,12 @@ def _column_totals_for_host_matrix( raise TypeError(f"Unsupported host matrix type: {type(X)}") total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) - total_sq_sums = ( - cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) - if sq_sums is not None - else None - ) total_nnz = ( cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) if nnz is not None else None ) - return total_sums, total_sq_sums, total_nnz + return total_sums, total_nnz def _host_ovr_totals_if_needed( @@ -482,14 +450,11 @@ def _host_ovr_totals_if_needed( group_codes: np.ndarray, n_groups: int, *, - compute_sq_sums: bool, compute_nnz: bool, -) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: +) -> tuple[cp.ndarray | None, cp.ndarray | None]: if not np.any(group_codes == n_groups): - return None, None, None - return _column_totals_for_host_matrix( - X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz - ) + return None, None + return _column_totals_for_host_matrix(X, compute_nnz=compute_nnz) def wilcoxon( @@ -574,16 +539,11 @@ def _wilcoxon_vs_rest( group_sizes_np = group_sizes.astype(np.float64, copy=False) group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - compute_vars = False compute_nnz = rg.comp_pts rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) tie_corr = cp.ones(n_total_genes, dtype=cp.float64) group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - group_sq_sums = cp.empty( - (n_groups, n_total_genes) if compute_vars else (1, 1), - dtype=cp.float64, - ) group_nnz = cp.empty( (n_groups, n_total_genes) if compute_nnz else (1, 1), dtype=cp.float64, @@ -606,13 +566,11 @@ def _wilcoxon_vs_rest( rank_sums, tie_corr, group_sums, - group_sq_sums, group_nnz, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_groups, compute_tie_corr=tie_correct, - compute_sq_sums=compute_vars, compute_nnz=compute_nnz, sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, ) @@ -633,35 +591,29 @@ def _wilcoxon_vs_rest( rank_sums, tie_corr, group_sums, - group_sq_sums, group_nnz, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_groups, compute_tie_corr=tie_correct, - compute_sq_sums=compute_vars, compute_nnz=compute_nnz, sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, ) if rg._compute_stats_in_chunks: - total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + total_sums, total_nnz = _host_ovr_totals_if_needed( X, group_codes, n_groups, - compute_sq_sums=compute_vars, compute_nnz=compute_nnz, ) _fill_basic_stats_from_accumulators( rg, group_sums, - group_sq_sums, group_nnz, group_sizes_np, n_cells=n_cells, - compute_vars=compute_vars, total_sums=total_sums, - total_sq_sums=total_sq_sums, total_nnz=total_nnz, ) @@ -887,7 +839,6 @@ def _wilcoxon_with_reference( rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) n_groups_stats = n_test + 1 - compute_vars = False compute_sums = rg._compute_stats_in_chunks compute_nnz = rg.comp_pts group_sums = cp.empty( @@ -896,10 +847,6 @@ def _wilcoxon_with_reference( else (1,), dtype=cp.float64, ) - group_sq_sums = cp.empty( - (n_groups_stats, n_total_genes) if compute_vars else (1,), - dtype=cp.float64, - ) group_nnz = cp.empty( (n_groups_stats, n_total_genes) if compute_nnz else (1,), dtype=cp.float64, @@ -934,7 +881,6 @@ def _wilcoxon_with_reference( rank_sums, tie_corr_arr, group_sums, - group_sq_sums, group_nnz, n_ref=n_ref, n_all_grp=n_all_grp, @@ -943,7 +889,6 @@ def _wilcoxon_with_reference( n_groups=n_test, n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, - compute_sq_sums=compute_vars, compute_nnz=compute_nnz, sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) @@ -964,7 +909,6 @@ def _wilcoxon_with_reference( rank_sums, tie_corr_arr, group_sums, - group_sq_sums, group_nnz, n_full_rows=X.shape[0], n_ref=n_ref, @@ -973,7 +917,6 @@ def _wilcoxon_with_reference( n_test=n_test, n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, - compute_sq_sums=compute_vars, compute_nnz=compute_nnz, compute_sums=compute_sums, sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, @@ -993,12 +936,10 @@ def _wilcoxon_with_reference( _fill_ovo_stats_from_accumulators( rg, group_sums, - group_sq_sums, group_nnz, group_sizes=group_sizes, test_group_indices=test_group_indices, n_ref=n_ref, - compute_vars=compute_vars, ) scores, p_values = _ovo_z_pvals( From 9f7d3e0bb144a95b4a9a5da38d7e64cc727cbf35 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 18 Jun 2026 00:19:45 +0200 Subject: [PATCH 20/36] improve memory and dtypes and nnz for large datasets --- src/rapids_singlecell/_cuda/rmm_scratch.cu | 23 + src/rapids_singlecell/_cuda/rmm_scratch.h | 8 + .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 71 +++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 29 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 115 +++-- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 10 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 435 ++++++++++++++++-- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 12 + .../tools/_rank_genes_groups/_wilcoxon.py | 8 + 9 files changed, 623 insertions(+), 88 deletions(-) diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.cu b/src/rapids_singlecell/_cuda/rmm_scratch.cu index efef484c..ea519d2d 100644 --- a/src/rapids_singlecell/_cuda/rmm_scratch.cu +++ b/src/rapids_singlecell/_cuda/rmm_scratch.cu @@ -2,6 +2,7 @@ #include #include +#include #include #include "rmm_scratch.h" @@ -25,3 +26,25 @@ void* rmm_allocate(size_t bytes) { void rmm_deallocate(void* ptr, size_t bytes) { rmm::mr::get_current_device_resource_ref().deallocate_sync(ptr, bytes); } + +// `fraction` * the free device memory reported by cudaMemGetInfo. +// +// Deliberately a plain query, NOT a trial-allocation probe. Probing a pool's +// internal free by allocating until it grows permanently RATCHETS the pool +// (RMM pools never shrink): repeated wilcoxon calls would grow it toward the +// whole device and then starve non-pool allocations like cudaStreamCreate +// ("out of memory" on stream creation). cudaMemGetInfo free is correct and +// safe everywhere: +// * Plain cuda: exact. +// * Pool: the memory OUTSIDE the pool's reservation; the pool also serves +// from its internal free, so this is conservative but never over-budgets +// and never grows the pool. The host-streaming paths transfer each nonzero +// once regardless of batch size (per-row cursor gather), so a smaller +// budget only adds a few more passes -- it does not re-stream. +// * Managed/UVM: device-resident free, so sizing to it avoids host spill. +size_t rmm_available_device_bytes(double fraction) { + if (fraction <= 0.0) return 0; + size_t free_b = 0, total_b = 0; + if (cudaMemGetInfo(&free_b, &total_b) != cudaSuccess) return 0; + return (size_t)(free_b * fraction); +} diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.h b/src/rapids_singlecell/_cuda/rmm_scratch.h index bf746dfc..cc674e80 100644 --- a/src/rapids_singlecell/_cuda/rmm_scratch.h +++ b/src/rapids_singlecell/_cuda/rmm_scratch.h @@ -13,6 +13,14 @@ void* rmm_allocate(size_t bytes); void rmm_deallocate(void* ptr, size_t bytes); +// fraction * cudaMemGetInfo free. A plain query, never a trial-allocation probe +// (probing a pool's internal free ratchets the pool to the whole device and +// starves non-pool allocations like cudaStreamCreate). Conservative under a +// pool but safe across the default cuda resource, a pool, and managed/UVM; the +// host-streaming paths transfer each nonzero once regardless of batch size, so +// a smaller budget only adds passes. Use for every GPU-memory-budget decision. +size_t rmm_available_device_bytes(double fraction); + // --------------------------------------------------------------------------- // Small allocation pool for temporary CUDA buffers. Frees everything on scope // exit; reuse a single pool across a kernel pipeline. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 4d8402df..25aa2402 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -1,9 +1,11 @@ #pragma once +#include #include #include #include #include +#include #include #include @@ -13,6 +15,46 @@ #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR #include "../rmm_scratch.h" // rmm_allocate, RmmScratchPool, ScopedCudaBuffer +// Host thread count for CPU-side CSR passes: hardware concurrency, capped. +static inline int host_worker_count() { + unsigned hw = std::thread::hardware_concurrency(); + return (int)std::min(hw ? hw : 4u, 32u); +} + +// Run fn(chunk, r0, r1) over a contiguous partition of [0, n); `chunk` is the +// 0-based worker index (for per-thread scratch). fn runs concurrently, so it +// must only read shared state and write disjoint output ranges (keyed by chunk +// or by [r0,r1)). Returns the number of chunks used. Serial for small n. +template +static inline int host_parallel_chunks(int n, F fn) { + if (n <= 0) return 0; + int n_threads = host_worker_count(); + if (n_threads <= 1 || n < 4096) { + fn(0, 0, n); + return 1; + } + int chunk = (n + n_threads - 1) / n_threads; + std::vector pool; + pool.reserve(n_threads); + for (int t = 0; t < n_threads; t++) { + int r0 = t * chunk; + if (r0 >= n) break; + int r1 = std::min(n, r0 + chunk); + pool.emplace_back([&fn, t, r0, r1]() { fn(t, r0, r1); }); + } + int used = (int)pool.size(); + for (std::thread& th : pool) th.join(); + return used; +} + +// Run fn(r0, r1) over a contiguous partition of [0, n) across hardware threads +// (serial for small n). fn is invoked concurrently, so it must only read shared +// state and write disjoint output ranges. Used for host-side CSR gathers. +template +static inline void host_parallel_ranges(int n, F fn) { + host_parallel_chunks(n, [&fn](int, int r0, int r1) { fn(r0, r1); }); +} + constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; constexpr int N_STREAMS = 4; @@ -131,6 +173,35 @@ static inline int checked_int_product(size_t a, size_t b, const char* context) { return (int)(a * b); } +// Largest per-batch nonzero count we let a column batch reach. A batch is +// sorted in a single CUB segmented call (int32 item count) and addressed with +// int offsets, so it must stay below INT_MAX with margin. +constexpr size_t SAFE_BATCH_NNZ = 2000000000; // < INT_MAX + +// Shrink a column sub-batch (halving) until the densest contiguous window of +// `sub_batch_cols` columns holds <= cap nonzeros, keeping every batch's nnz +// within int32 for CUB and bounding the per-stream transpose/sort scratch. +// `col_nnz(i)` returns the nonzero count of column i. Worst case returns 1 +// (a single column, whose nnz is <= n_rows). +template +static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, + size_t cap, ColNnz col_nnz) { + if (cap < 1) cap = 1; + auto max_window = [&](int s) { + size_t mx = 0; + for (int c = 0; c < n_cols; c += s) { + int e = std::min(c + s, n_cols); + size_t sum = 0; + for (int i = c; i < e; i++) sum += col_nnz(i); + if (sum > mx) mx = sum; + } + return mx; + }; + while (sub_batch_cols > 1 && max_window(sub_batch_cols) > cap) + sub_batch_cols = (sub_batch_cols + 1) / 2; + return sub_batch_cols; +} + // --------------------------------------------------------------------------- // RAII guard for cudaHostRegister. Unregisters on scope exit even when an // exception unwinds — prevents leaked host pinning on stream-sync failures. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index c325fc42..e877ea11 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -16,6 +16,16 @@ static void ovo_streaming_csr_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + // Cap sub_batch_cols so the dense group slab (n_all_grp × sub_batch_cols, + // sorted in one CUB call) stays within int32. n_all_grp is a cell count, so + // it drives the cap; the reference side is chunked separately below. + { + size_t cap = n_all_grp > 0 ? SAFE_BATCH_NNZ / (size_t)n_all_grp + : (size_t)sub_batch_cols; + if (cap < 1) cap = 1; + if ((size_t)sub_batch_cols > cap) sub_batch_cols = (int)cap; + } + std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); @@ -43,11 +53,11 @@ static void ovo_streaming_csr_impl( "OVO device CSR reference group exceeds CUB int item limit"); } int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); - size_t free_bytes = 0; - size_t total_bytes = 0; - if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) { + { + // Reference cache holds 2 floats/col/ref-row; size it to ~a third of + // what the joint allocator can serve (leaving room for group buffers). size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; - size_t target_bytes = free_bytes / 3; + size_t target_bytes = rmm_available_device_bytes(1.0 / 3.0); if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { size_t mem_cols = target_bytes / bytes_per_col; if (mem_cols > 0 && mem_cols < (size_t)ref_cache_cols) { @@ -234,6 +244,17 @@ static void ovo_streaming_csc_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + // Cap sub_batch_cols so both dense slabs (n_ref × sub_batch_cols and + // n_all_grp × sub_batch_cols, each sorted in one CUB call) stay within + // int32. These row counts are cell counts, so they drive the cap. + { + size_t max_rows = (size_t)std::max(n_ref, n_all_grp); + size_t cap = + max_rows > 0 ? SAFE_BATCH_NNZ / max_rows : (size_t)sub_batch_cols; + if (cap < 1) cap = 1; + if ((size_t)sub_batch_cols > cap) sub_batch_cols = (int)cap; + } + std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 2b82aed9..d5cfc088 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -17,6 +17,20 @@ static void ovo_streaming_csc_host_impl( bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + // Cap sub_batch_cols so neither the dense ref/group slabs (rows × + // sub_batch_cols, sorted in one CUB call) nor the per-column-batch nnz + // exceed int32. rows here are cell counts, so they dominate the dense cap. + { + size_t max_rows = (size_t)std::max(n_ref, n_all_grp); + size_t dense_cap = + max_rows > 0 ? SAFE_BATCH_NNZ / max_rows : (size_t)sub_batch_cols; + if (dense_cap < 1) dense_cap = 1; + if ((size_t)sub_batch_cols > dense_cap) sub_batch_cols = (int)dense_cap; + sub_batch_cols = cap_sub_batch_by_nnz( + n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); + } + // ---- Tier dispatch from host offsets ---- auto t1 = make_ovo_tier_plan(h_grp_offsets, n_groups); int max_grp_size = t1.max_grp_size; @@ -62,6 +76,23 @@ static void ovo_streaming_csc_host_impl( if (nnz > max_nnz) max_nnz = nnz; } + // Reduce the stream count so the per-stream scratch fits the memory budget. + // The dense ref/group slabs scale with n_ref/n_all_grp (cell counts), so at + // scale a fixed N_STREAMS would exceed GPU memory and thrash/OOM. + { + size_t per_stream = + max_nnz * (sizeof(InT) + sizeof(float) + sizeof(IndexT)) + + 2 * sub_ref_items * sizeof(float) + + (run_huge ? 2 : 1) * sub_grp_items * sizeof(float) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + (compute_nnz ? 2 : 1) * (size_t)n_groups_stats * sub_batch_cols * + sizeof(double) + + cub_temp_bytes; + size_t budget = rmm_available_device_bytes(0.8); + while (n_streams > 1 && (size_t)n_streams * per_stream > budget) + n_streams--; + } + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamDefault); @@ -365,6 +396,10 @@ static void ovo_streaming_csr_host_impl( GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; if ((size_t)target_rows > budget_cap_rows) target_rows = (int)budget_cap_rows; + // Also bound each pack's compacted nnz: it feeds int32 CUB item counts + // and int offsets, so a dense pack must stay under INT_MAX. This splits + // dense perturbation groups across more packs. + constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe int cur_first = 0; int cur_rows = 0; @@ -374,7 +409,9 @@ static void ovo_streaming_csr_host_impl( size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - h_grp_indptr_compact[h_grp_offsets[g]]); int new_rows = cur_rows + n_g; - bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + bool can_add = + (cur_rows == 0) || + (new_rows <= target_rows && cur_nnz + nnz_g <= SAFE_PACK_NNZ); if (!can_add) { size_t sb_size = std::min((size_t)n_cols, @@ -469,37 +506,39 @@ static void ovo_streaming_csr_host_impl( cudaMemcpyHostToDevice); // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + // The full-width sorted reference cache d_ref_sorted is [n_ref × n_cols], + // but it is built one COLUMN CHUNK at a time so each CUB segmented sort + // stays within int32 (n_ref × ref_chunk_cols items) and the dense extract + // scratch is bounded to a chunk instead of the whole [n_ref × n_cols] slab. + // This is what lets large references (n_ref × n_cols > INT_MAX) work. size_t ref_items = (size_t)n_ref * (size_t)n_cols; - if (n_ref > 0 && (size_t)n_cols > (size_t)std::numeric_limits::max() / - (size_t)n_ref) { - throw std::runtime_error( - "OVO host CSR dense reference cache exceeds CUB int item limit; " - "use native CSC/device sparse input or reduce genes/reference " - "size"); - } if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { throw std::runtime_error( "OVO host CSR dense reference cache size overflows size_t"); } - size_t free_bytes = 0; - size_t total_bytes = 0; - if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess && - total_bytes > 0 && ref_items * 2 * sizeof(float) > total_bytes) { + size_t ref_avail = rmm_available_device_bytes(0.9); + if (ref_avail > 0 && ref_items * sizeof(float) > ref_avail) { throw std::runtime_error( - "OVO host CSR dense reference cache requires more GPU memory than " - "the device provides; use native CSC/device sparse input or reduce " + "OVO host CSR sorted reference cache requires more GPU memory than " + "is available; use native CSC/device sparse input or reduce " "genes/reference size"); } - int ref_items_i32 = - checked_cub_items(ref_items, "OVO host CSR dense reference cache"); + int ref_chunk_cols = + n_ref > 0 + ? (int)std::min((size_t)n_cols, SAFE_BATCH_NNZ / (size_t)n_ref) + : n_cols; + if (ref_chunk_cols < 1) ref_chunk_cols = 1; + size_t ref_chunk_items = (size_t)n_ref * (size_t)ref_chunk_cols; + int ref_chunk_items_i32 = + checked_cub_items(ref_chunk_items, "OVO host CSR ref column chunk"); float* d_ref_sorted = pool.alloc(ref_items); ScopedCudaStream ref_stream(cudaStreamNonBlocking); { ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); - ScopedCudaBuffer ref_dense_buf(ref_items * sizeof(float)); - ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); + ScopedCudaBuffer ref_dense_buf(ref_chunk_items * sizeof(float)); + ScopedCudaBuffer ref_seg_buf((ref_chunk_cols + 1) * sizeof(int)); float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); int* d_ref_indices = (int*)ref_indices_buf.data(); @@ -512,7 +551,8 @@ static void ovo_streaming_csr_host_impl( (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); // Fused gather + cast + stats for ref (fixed slot = n_test). One - // pass over PCIe, no intermediate native-dtype GPU buffer. + // pass over PCIe, no intermediate native-dtype GPU buffer. Stats for + // all columns are accumulated here, once. if (n_ref > 0 && ref_nnz > 0) { csr_gather_cast_accumulate_mapped_kernel <<>>( @@ -524,28 +564,33 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); } - // Extract ref dense (F-order) from compacted CSR. - cudaMemsetAsync(d_ref_dense, 0, ref_items * sizeof(float), ref_stream); - { + size_t ref_cub_bytes = cub_segmented_sortkeys_temp_bytes( + ref_chunk_items_i32, ref_chunk_cols); + ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); + + // Extract + segment-sort the reference one column chunk at a time, + // writing each chunk into its slice of the full-width sorted cache. + for (int cs = 0; cs < n_cols; cs += ref_chunk_cols) { + int ce = std::min(cs + ref_chunk_cols, n_cols); + int cc = ce - cs; + size_t chunk_items = (size_t)n_ref * (size_t)cc; + cudaMemsetAsync(d_ref_dense, 0, chunk_items * sizeof(float), + ref_stream); csr_extract_dense_identity_rows_unsorted_kernel <<>>( d_ref_data_f32, d_ref_indices, d_ref_indptr, d_ref_dense, - n_ref, 0, n_cols); + n_ref, cs, ce); CUDA_CHECK_LAST_ERROR( csr_extract_dense_identity_rows_unsorted_kernel); + upload_linear_offsets(d_ref_seg, cc, n_ref, ref_stream); + size_t temp = ref_cub_bytes; + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, + d_ref_sorted + (size_t)cs * (size_t)n_ref, + (int)chunk_items, cc, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT, ref_stream), + "host CSR OVO ref segmented sort"); } - - // Segmented sort ref_dense by column → ref_sorted - size_t ref_cub_bytes = - cub_segmented_sortkeys_temp_bytes(ref_items_i32, n_cols); - ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); - upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); - size_t temp = ref_cub_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, - ref_items_i32, n_cols, d_ref_seg, d_ref_seg + 1, - BEGIN_BIT, END_BIT, ref_stream), - "host CSR OVO ref segmented sort"); cuda_check(cudaStreamSynchronize(ref_stream), "host CSR OVO ref sort sync"); } // ref scratch drops here diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index a5bb3f58..bb12ac3e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -26,13 +26,19 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, * rows would silently drop or misplace nonzeros. Every caller enforces this -- * the Python dispatch calls `sort_indices()` on the CSR/CSC input before * invoking the streaming impls that launch this kernel. + * + * `row_offset` is added to the local row index when writing csc_row_idx, so a + * row-block whose indptr/data are rebased to a local [0, n_rows) range still + * records the correct global row id (used by the out-of-core row-streaming OVR + * path that feeds bulk-transferred row-blocks). Defaults to 0 for callers that + * pass the full matrix. */ template __global__ void csr_scatter_to_csc_kernel( const InT* __restrict__ data, const IndexT* __restrict__ indices, const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, - int col_start, int col_stop) { + int col_start, int col_stop, int row_offset = 0) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= n_rows) return; IndptrT rs = indptr[row]; @@ -51,7 +57,7 @@ __global__ void csr_scatter_to_csc_kernel( if (c >= col_stop) break; int dest = atomicAdd(&write_pos[c - col_start], 1); csc_vals[dest] = data[p]; - csc_row_idx[dest] = row; + csc_row_idx[dest] = row_offset + row; } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 1c71bd0f..dc6b2415 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -16,6 +16,21 @@ static void ovr_sparse_csc_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; + // Bound each column batch's nnz so CUB item counts stay within int32 and + // the per-stream sort buffers fit the memory budget (column counts come + // free from the CSC indptr). + { + constexpr size_t BYTES_PER_NNZ = + sizeof(InT) + 2 * sizeof(float) + 2 * sizeof(IndexT) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + sub_batch_cols = cap_sub_batch_by_nnz( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); + } + int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; @@ -219,6 +234,244 @@ static void ovr_sparse_csc_host_streaming_impl( // Sparse-aware host-streaming CSR OVR pipeline. // ============================================================================ +/** + * Out-of-core OVR for a host CSR too large to stage on the GPU. + * + * Column indices are sorted within each row, so a per-row cursor (init 0) lets + * us walk the matrix ONCE: for each ascending column batch [col, col_end) every + * row resumes where the previous batch stopped and emits its run of columns in + * range. The cursor advances monotonically, so each nonzero is read on the host + * and bulk-transferred exactly once over the matrix's lifetime -- the gathered, + * already-compacted per-batch slice is the only thing that crosses the bus (a + * true 1x transfer, versus re-streaming the whole CSR per batch). The column + * histogram is counted on the host; the full CSR is never page-locked (the + * gather reads it on the CPU). Single stream: the per-batch CSC accumulator + * plus the gather buffers fill the device. + */ +template +static void ovr_sparse_csr_host_rowstream_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + RmmScratchPool pool; + int tpb = UTIL_BLOCK_SIZE; + size_t budget = rmm_available_device_bytes(0.8); + + // ---- Phase 0: column histogram on the host, threaded by row range. Each + // worker counts its rows' columns into a private array (cache-resident, no + // false sharing), merged afterwards. No device transfer. ---- + std::vector h_col_counts(n_cols, 0); + { + int n_workers = host_worker_count(); + std::vector> local(n_workers, + std::vector(n_cols, 0)); + int used = host_parallel_chunks(n_rows, [&](int w, int r0, int r1) { + std::vector& lc = local[w]; + for (IndptrT p = h_indptr[r0]; p < h_indptr[r1]; p++) + lc[(size_t)h_indices[p]]++; + }); + for (int w = 0; w < used; w++) + for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; + } + + // ---- Column batch size: int32 CUB limit + device buffers that fit the + // budget. Per-nnz device footprint: gathered mini-CSR (val + col) + CSC + // accumulator (val + f32 + row) + sort outputs (key + row) + CUB temp. ---- + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(InT) // gather val + csc val + + 2 * sizeof(float) // f32 key in + out + + 3 * sizeof(int) // gather col + 2 rows + + 12; // CUB temp headroom + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = budget / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + sub_batch_cols = cap_sub_batch_by_nnz( + n_cols, sub_batch_cols, cap, [&](int c) { return h_col_counts[c]; }); + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb; i++) + off[i + 1] = + checked_int_span((size_t)off[i] + h_col_counts[col_start + i], + "rowstream rebased column offsets"); + h_batch_nnz[b] = (size_t)off[sb]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int mb_i32 = + checked_cub_items(max_batch_nnz, "rowstream sub-batch nnz"); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(mb_i32, sub_batch_cols); + } + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = + cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); + + // ---- Host gather staging (pinned for fast bulk H2D) + per-row cursor. The + // full CSR is NOT page-locked: the gather reads it on the CPU; only the + // compacted per-batch slice crosses the bus. ---- + size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; + std::vector h_gather_vals(stage_nnz); + std::vector h_gather_cols(stage_nnz); + std::vector h_gather_indptr(n_rows + 1); + HostRegisterGuard pin_gvals(h_gather_vals.data(), stage_nnz * sizeof(InT)); + HostRegisterGuard pin_gcols(h_gather_cols.data(), stage_nnz * sizeof(int)); + std::vector cursor(n_rows, 0); // offset within each (sorted) row + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + InT* d_gather_vals = pool.alloc(max_batch_nnz); + int* d_gather_cols = pool.alloc(max_batch_nnz); + int* d_gather_indptr = pool.alloc(n_rows + 1); + int* col_offsets = pool.alloc(sub_batch_cols + 1); + int* write_pos = pool.alloc(sub_batch_cols); + InT* csc_vals_orig = pool.alloc(max_batch_nnz); + float* csc_vals_f32 = pool.alloc(max_batch_nnz); + int* csc_row_idx = pool.alloc(max_batch_nnz); + float* keys_out = pool.alloc(max_batch_nnz); + int* vals_out = pool.alloc(max_batch_nnz); + uint8_t* cub_temp = pool.alloc(cub_temp_bytes); + double* sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + double* sub_tie_corr = pool.alloc(sub_batch_cols); + double* sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + double* d_nz_scratch = + rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + + // ---- One linear pass over the matrix, column-batched. The per-row cursor + // advances monotonically (indices sorted + batches ascending), so every + // nonzero is read on the host and transferred exactly once across all + // batches -- no whole-matrix re-streaming. The host gather is threaded: + // count each row's run (binary search), prefix-sum to per-row output + // offsets, then copy rows in parallel into disjoint staging ranges. ---- + std::vector g_count(n_rows); + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int col_end = col + sb_cols; + + // Count this batch's run per row: sorted indices -> binary search from + // the cursor for the first column >= col_end. + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + const IndexT* lo = h_indices + h_indptr[r] + cursor[r]; + const IndexT* hi = h_indices + h_indptr[r + 1]; + g_count[r] = + (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); + } + }); + // Prefix sum -> per-row output offsets (the gathered mini-CSR's row + // pointer). + h_gather_indptr[0] = 0; + for (int r = 0; r < n_rows; r++) + h_gather_indptr[r + 1] = checked_int_span( + (size_t)h_gather_indptr[r] + (size_t)g_count[r], + "rowstream gather nnz"); + int batch_nnz = h_gather_indptr[n_rows]; + // Copy each row's run into its slot and advance its cursor (disjoint + // output ranges -> race-free). + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + IndptrT base = h_indptr[r] + cursor[r]; + size_t gpos = (size_t)h_gather_indptr[r]; + int cnt = g_count[r]; + for (int k = 0; k < cnt; k++) { + h_gather_vals[gpos + k] = h_data[base + k]; + h_gather_cols[gpos + k] = (int)h_indices[base + k]; + } + cursor[r] += cnt; + } + }); + + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + cudaMemcpy(col_offsets, off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(write_pos, off, sb_cols * sizeof(int), + cudaMemcpyHostToDevice); + + // Bulk H2D of just this batch's compacted nonzeros (each transferred + // once over the matrix's lifetime) + the per-row split. + if (batch_nnz > 0) { + cuda_check(cudaMemcpy(d_gather_vals, h_gather_vals.data(), + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice), + "rowstream gathered vals H2D"); + cuda_check(cudaMemcpy(d_gather_cols, h_gather_cols.data(), + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice), + "rowstream gathered cols H2D"); + } + cudaMemcpy(d_gather_indptr, h_gather_indptr.data(), + (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Scatter the gathered mini-CSR into the column-batch CSC accumulator. + csr_scatter_to_csc_kernel + <<<(n_rows + tpb - 1) / tpb, tpb>>>( + d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, + csc_vals_orig, csc_row_idx, n_rows, col, col_end, 0); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + launch_ovr_cast_and_accumulate_sparse( + csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, + d_group_codes, sub_group_sums, sub_group_nnz, sb_cols, n_groups, + compute_nnz, tpb, smem_cast, cast_use_gmem, 0); + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + cub_temp, temp, csc_vals_f32, keys_out, csc_row_idx, + vals_out, batch_nnz, sb_cols, col_offsets, + col_offsets + 1, BEGIN_BIT, END_BIT), + "rowstream segmented sort"); + } + launch_ovr_sparse_rank( + keys_out, vals_out, col_offsets, d_group_codes, d_group_sizes, + sub_rank_sums, sub_tie_corr, d_nz_scratch, n_rows, sb_cols, + n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, 0); + + cudaMemcpy2D(d_rank_sums + col, n_cols * sizeof(double), sub_rank_sums, + sb_cols * sizeof(double), sb_cols * sizeof(double), + n_groups, cudaMemcpyDeviceToDevice); + if (compute_tie_corr) + cudaMemcpy(d_tie_corr + col, sub_tie_corr, sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice); + cudaMemcpy2D(d_group_sums + col, n_cols * sizeof(double), + sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice); + if (compute_nnz) + cudaMemcpy2D(d_group_nnz + col, n_cols * sizeof(double), + sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice); + col += sb_cols; + } + cuda_check(cudaDeviceSynchronize(), "rowstream sync"); +} + /** * Host CSR variant of the sparse OVR stream. * @@ -236,20 +489,91 @@ static void ovr_sparse_csr_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; + // Declared before the pool/streams so on exception unwind the streams + // drain (kernels finish reading any mapped host memory) before it is + // unregistered. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + RmmScratchPool pool; size_t total_nnz = (size_t)h_indptr[n_rows]; - // ---- Phase 0: CPU planning in native CSR order ---- - std::vector h_col_counts(n_cols, 0); - for (int row = 0; row < n_rows; row++) { - IndptrT rs = h_indptr[row]; - IndptrT re = h_indptr[row + 1]; - for (IndptrT p = rs; p < re; ++p) { - int c = (int)h_indices[p]; - if (c >= 0 && c < n_cols) h_col_counts[c]++; + size_t budget = rmm_available_device_bytes(0.8); + + int tpb = UTIL_BLOCK_SIZE; + size_t data_bytes = total_nnz * sizeof(InT); + size_t idx_bytes = total_nnz * sizeof(IndexT); + + // When the matrix is too large to stage on the device, the per-batch + // scatter would fall back to bus-latency-bound zero-copy reads. Page the + // CSR through the GPU in row blocks (pinned bulk H2D) instead. + if (total_nnz > 0 && data_bytes + idx_bytes > (budget * 3) / 4) { + ovr_sparse_csr_host_rowstream_impl( + h_data, h_indices, h_indptr, h_group_codes, h_group_sizes, + d_rank_sums, d_tie_corr, d_group_sums, d_group_nnz, n_rows, n_cols, + n_groups, compute_tie_corr, compute_nnz, sub_batch_cols); + return; + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // Stage the indices on the device when they fit, so the per-column + // histogram and the per-batch CSR->CSC scatter read them at HBM speed + // rather than over the bus. Indices are needed by both, so they are staged + // first; the (equally sized) data array is staged later only if it too + // fits. A bulk pageable copy is driver-staged -- no host registration. + IndexT* d_indices = nullptr; + bool indices_staged = total_nnz > 0 && idx_bytes <= budget / 2; + if (total_nnz > 0) { + if (indices_staged) { + d_indices = pool.alloc(total_nnz); + cuda_check(cudaMemcpy(d_indices, h_indices, idx_bytes, + cudaMemcpyHostToDevice), + "OVR host CSR stage indices H2D"); + } else { + pin_indices = HostRegisterGuard(const_cast(h_indices), + idx_bytes, cudaHostRegisterMapped); + cuda_check( + cudaHostGetDevicePointer((void**)&d_indices, + const_cast(h_indices), 0), + "OVR host CSR map indices"); } } + // ---- Phase 0: per-column nnz counts on the GPU ---- + // CSR has no column structure, so counting on the CPU is a serial pass over + // every nonzero. Histogram the device-accessible indices instead; only the + // n_cols counts come back for the per-batch prefix sums. + std::vector h_col_counts(n_cols, 0); + if (total_nnz > 0) { + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); + int hist_blocks = (n_rows + tpb - 1) / tpb; + csr_col_histogram_kernel<<>>( + d_indices, d_indptr_full, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + cuda_check( + cudaMemcpy(h_col_counts.data(), d_col_counts, + n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost), + "OVR host CSR column-count D2H"); + } + + // Each column batch is sorted in one CUB segmented call (int32 item count) + // and its CSR->CSC transpose lives in per-stream scratch (~BYTES_PER_NNZ + // per stored nonzero). Shrink sub_batch_cols until the densest window fits + // BOTH the int32 limit AND a per-stream slice of the budget, so very tall + // matrices neither overflow CUB nor exhaust memory. + constexpr size_t BYTES_PER_NNZ = sizeof(InT) + sizeof(float) + + 2 * sizeof(int) + 8; // buffers + CUB temp + size_t batch_nnz_cap = SAFE_BATCH_NNZ; + size_t mem_cap = budget / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < batch_nnz_cap) batch_nnz_cap = mem_cap; + sub_batch_cols = + cap_sub_batch_by_nnz(n_cols, sub_batch_cols, batch_nnz_cap, + [&](int c) { return (size_t)h_col_counts[c]; }); + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t max_batch_nnz = 0; std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); @@ -271,7 +595,7 @@ static void ovr_sparse_csr_host_streaming_impl( cudaMemcpy(d_all_offsets, h_all_offsets.data(), h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - // ---- Phase 1: allocate per-stream bounded work buffers ---- + // ---- Phase 1: per-stream bounded work buffer size + stream count ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { int max_batch_nnz_i32 = checked_cub_items( @@ -280,16 +604,12 @@ static void ovr_sparse_csr_host_streaming_impl( sub_batch_cols); } - int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); - int n_streams = N_STREAMS; - if (n_batches < n_streams) n_streams = n_batches; - size_t per_stream_bytes = max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + @@ -302,43 +622,39 @@ static void ovr_sparse_csr_host_streaming_impl( per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } - size_t free_mem = 0, total_mem = 0; - cudaMemGetInfo(&free_mem, &total_mem); - constexpr double MEM_BUDGET_FRAC = 0.8; - size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); - while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + // Stage the data array too when the indices are already resident and the + // data + at least one stream's transpose buffers still fit; the scatter + // then reads values at HBM speed. Otherwise the data stays mapped + // zero-copy (bounded for matrices too large to stage). + size_t resident = indices_staged ? idx_bytes : 0; + bool data_staged = total_nnz > 0 && indices_staged && + resident + data_bytes + per_stream_bytes <= budget; + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + size_t stream_budget = budget - resident - (data_staged ? data_bytes : 0); + while (n_streams > 1 && + (size_t)n_streams * per_stream_bytes > stream_budget) n_streams--; ScopedCudaStreams streams(n_streams, cudaStreamDefault); - // Pin the source CSR arrays as mapped memory. The scatter kernel reads - // only the requested column window from each row. - HostRegisterGuard pin_data; - HostRegisterGuard pin_indices; - InT* d_data_zc = nullptr; - IndexT* d_indices_zc = nullptr; + InT* d_data = nullptr; if (total_nnz > 0) { - pin_data = - HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), - cudaHostRegisterMapped); - pin_indices = HostRegisterGuard(const_cast(h_indices), - total_nnz * sizeof(IndexT), - cudaHostRegisterMapped); - cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, - const_cast(h_data), 0); - cudaError_t e2 = cudaHostGetDevicePointer( - (void**)&d_indices_zc, const_cast(h_indices), 0); - if (e1 != cudaSuccess || e2 != cudaSuccess) { - throw std::runtime_error( - std::string("cudaHostGetDevicePointer failed: ") + - cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + if (data_staged) { + d_data = pool.alloc(total_nnz); + cuda_check( + cudaMemcpy(d_data, h_data, data_bytes, cudaMemcpyHostToDevice), + "OVR host CSR stage data H2D"); + } else { + pin_data = HostRegisterGuard(const_cast(h_data), data_bytes, + cudaHostRegisterMapped); + cuda_check(cudaHostGetDevicePointer((void**)&d_data, + const_cast(h_data), 0), + "OVR host CSR map data"); } } - IndptrT* d_indptr_full = pool.alloc(n_rows + 1); - cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), - cudaMemcpyHostToDevice); - int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), @@ -408,7 +724,7 @@ static void ovr_sparse_csr_host_streaming_impl( if (batch_nnz > 0) { csr_scatter_to_csc_kernel <<>>( - d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + d_data, d_indices, d_indptr_full, buf.write_pos, buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, col + sb_cols); CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); @@ -485,6 +801,20 @@ static void ovr_sparse_csc_streaming_impl( cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(IndptrT), cudaMemcpyDeviceToHost); + // Bound each column batch's nnz so CUB item counts stay within int32 and + // the per-stream sort buffers fit the budget. + { + constexpr size_t BYTES_PER_NNZ = + 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + sub_batch_cols = cap_sub_batch_by_nnz( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); + } + int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; @@ -642,6 +972,20 @@ static void ovr_sparse_csr_streaming_impl( cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost); + // Bound each column batch's nnz so CUB item counts stay within int32 and + // the per-stream transpose/sort buffers fit the budget. + { + constexpr size_t BYTES_PER_NNZ = + 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + sub_batch_cols = cap_sub_batch_by_nnz( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)h_col_counts[c]; }); + } + // Per-batch prefix sums on host int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t max_batch_nnz = 0; @@ -696,10 +1040,7 @@ static void ovr_sparse_csr_streaming_impl( per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } - size_t free_mem = 0, total_mem = 0; - cudaMemGetInfo(&free_mem, &total_mem); - constexpr double MEM_BUDGET_FRAC = 0.8; - size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + size_t budget = rmm_available_device_bytes(0.8); while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) n_streams--; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 9feeae0b..fdb2510d 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -121,6 +121,12 @@ void register_sparse_bindings(nb::module_& m) { int); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, int64_t); + // int64 column indices (int64 indptr): pass indices natively to avoid a + // full int32 copy of every nonzero (~nnz*4 bytes) on large matrices. + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64_idx64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64_idx64", double, + int64_t, int64_t); #undef RSC_OVR_SPARSE_CSR_HOST_BINDING #define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndptrCType) \ @@ -232,6 +238,12 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, int64_t); + // int64 column indices (int64 indptr): pass indices natively to avoid a + // full int32 copy of every nonzero (~nnz*4 bytes) on large matrices. + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64_idx64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64_idx64", double, + int64_t, int64_t); #undef RSC_OVO_CSR_HOST_BINDING } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 159b7fe0..70316b61 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -348,6 +348,14 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X): suffix += "_f64" if is_i64: suffix += "_i64" + # int64 column indices: if a native-int64 binding exists for this path, use + # it and pass the indices as-is. astype(int32) on int64 indices materializes + # a full copy of every nonzero (~nnz * 4 bytes, e.g. tens of GB on large + # matrices), so avoid it when the kernel can read int64 directly. + if X.indices.dtype == np.int64: + idx_fn = getattr(module, base_name + suffix + "_idx64", None) + if idx_fn is not None: + return idx_fn, data_arr, X.indices fn = getattr(module, base_name + suffix) indices_arr = X.indices.astype(np.int32, copy=False) return fn, data_arr, indices_arr From 6327d601aaa1a3f7d49bab04a16f28b4532bcd78 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 18 Jun 2026 12:23:01 +0200 Subject: [PATCH 21/36] update comments --- docs/installation.md | 13 +-- src/rapids_singlecell/_cuda/__init__.py | 10 +-- src/rapids_singlecell/_cuda/nb_types.h | 31 ++++--- .../_cuda/rank_genes/csr_tile_to_dense.cuh | 9 +- .../_cuda/rank_genes/rank_stats.cu | 15 ++-- src/rapids_singlecell/_cuda/rmm_scratch.h | 11 +-- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 12 +-- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 75 +++++----------- .../_cuda/wilcoxon/wilcoxon.cu | 90 ++++++++++++++----- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 31 +++---- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 24 ++--- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 64 ++++++------- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 46 ++++------ .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 53 +++++------ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 63 +++++-------- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 23 +++-- .../tools/_rank_genes_groups/__init__.py | 23 +++-- .../tools/_rank_genes_groups/_core.py | 31 +++---- .../tools/_rank_genes_groups/_utils.py | 30 ++++--- .../tools/_rank_genes_groups/_wilcoxon.py | 33 ++++--- tests/test_rank_genes_groups_wilcoxon.py | 23 ----- 21 files changed, 330 insertions(+), 380 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 35fecdc8..a21f9dcd 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -66,11 +66,14 @@ This installs the precompiled CUDA kernels but **not** the RAPIDS stack (cupy, c This is the recommended approach for **conda/mamba users** who already have RAPIDS installed in their environment. ```{note} -The compiled kernels (Wilcoxon, GMM, …) link `librmm` / `rapids_logger` at -runtime. These are **required**: they are provided by an existing RAPIDS -conda/mamba environment or by the `[rapids]`/`[rapids-cuXX]` extra below. -Installing the bare `rapids-singlecell-cuXX` wheel into an environment without -RAPIDS raises an `ImportError` when those kernels are first used. +The RAPIDS stack is **required**, not optional: `rapids_singlecell` imports +`cuml`/`cupy` at the top of its package `__init__`, and the compiled kernels +(Wilcoxon, GMM, …) link `librmm` / `rapids_logger` at runtime. These are +provided by an existing RAPIDS conda/mamba environment or by the +`[rapids]`/`[rapids-cuXX]` extra below. Installing the bare +`rapids-singlecell-cuXX` wheel into an environment without RAPIDS raises an +`ImportError` on `import rapids_singlecell` itself — not merely when a kernel is +first used. ``` ### Prebuilt wheels with RAPIDS dependencies diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 625a145f..2b93a142 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -75,14 +75,12 @@ def __getattr__(name: str): try: return importlib.import_module(f".{name}", __name__) except ModuleNotFoundError: - # Extension genuinely absent (e.g. docs builds, no-GPU installs): - # degrade to None so module-level imports don't raise. + # Extension genuinely absent (docs/no-GPU): degrade to None. return None except ImportError as exc: - # Extension present but failed to load (ABI/toolkit mismatch, a - # missing shared library, the rmm symbol-ordering issue, ...). - # Surface it with context instead of silently returning None and - # crashing later with a cryptic ``'NoneType' has no attribute ...``. + # Present but failed to load (ABI/toolkit mismatch, missing .so, rmm + # symbol-ordering): surface with context, don't return None and crash + # later with a cryptic ``'NoneType' has no attribute ...``. msg = ( f"Failed to load compiled CUDA extension {name!r}: {exc}. " "Ensure a matching rapids-singlecell-cuXX wheel (and librmm) is " diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index cd4fe4d6..23ba8958 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -8,9 +8,8 @@ namespace nb = nanobind; -/// Check the last CUDA error after a kernel launch. -/// Call immediately after every <<<...>>> launch to catch configuration errors -/// (invalid grid/block, shared memory overflow, etc.) before they propagate. +/// Check cudaGetLastError after a <<<...>>> launch (invalid grid/block, +/// shared memory overflow, etc.). inline void cuda_check_last_error(const char* kernel_name) { cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { @@ -21,12 +20,9 @@ inline void cuda_check_last_error(const char* kernel_name) { #define CUDA_CHECK_LAST_ERROR(kernel_name) cuda_check_last_error(#kernel_name) -/// Check a cudaError_t returned directly by a CUDA/CUB API call. -/// Unlike CUDA_CHECK_LAST_ERROR (which inspects cudaGetLastError after a -/// <<<...>>> launch), this validates the status a function call returns -- e.g. -/// cub::DeviceSegmentedRadixSort::SortKeys or cudaStreamSynchronize -- so a -/// failed call surfaces here with a clear label instead of as corrupted output -/// at a later synchronization point. +/// Check a cudaError_t returned directly by a CUDA/CUB API call (vs. +/// CUDA_CHECK_LAST_ERROR which inspects state after a launch), so a failed +/// call surfaces with a clear label instead of as corrupted output later. inline void cuda_check(cudaError_t err, const char* what) { if (err != cudaSuccess) { throw std::runtime_error(std::string(what) + @@ -34,6 +30,16 @@ inline void cuda_check(cudaError_t err, const char* what) { } } +/// Validate a binding-argument precondition (array dims vs. scalar shapes). +/// Throws std::invalid_argument so a mismatch is a clean Python error instead +/// of an out-of-bounds kernel launch. +inline void nb_require(bool cond, const char* what) { + if (!cond) { + throw std::invalid_argument( + std::string("rank_genes_groups CUDA binding: ") + what); + } +} + /// Per-axis cached cap on `gridDim.{x,y,z}`. These differ in CUDA: /// gridDim.x: 2^31-1 on CC 3.0+ /// gridDim.y: 65535 on most GPUs @@ -83,8 +89,7 @@ inline unsigned int strided_grid(long long nwork, int block_size) { return (unsigned int)(capped < 1 ? 1 : capped); } -/// Like `strided_grid` but for the y-axis of a 2D/3D grid (much lower cap, -/// typically 65535). Use when the y dimension is the one being strided over. +/// Like `strided_grid` but for the y-axis (much lower cap, typically 65535). inline unsigned int strided_grid_y(long long nwork, int block_size) { const long long max_grid = max_grid_dim_y(); long long ideal = (nwork + block_size - 1) / block_size; @@ -105,11 +110,11 @@ using gpu_array_c = nb::ndarray; template using gpu_array_f = nb::ndarray; -// No contiguity constraint (accepts any order) +// No contiguity constraint template using gpu_array = nb::ndarray; -// Parameterized contiguity (for kernels that handle both C and F order) +// Parameterized contiguity (kernels handling both C and F order) template using gpu_array_contig = nb::ndarray; diff --git a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh index f80ada7b..b6e881c4 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh +++ b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh @@ -2,13 +2,12 @@ #include -// CSR-slice + densify in a single pass: scatter the nonzeros of column window -// [col_lb, col_ub) straight into a dense (n_cells, col_ub-col_lb) F-order -// (column-major) double buffer. This skips the CSR -> CSC tile rebuild that a -// `X[:, lb:ub].tocsc()` densify would do. +// Single-pass CSR-slice + densify: scatter column window [col_lb, col_ub) into +// a dense (n_cells, col_ub-col_lb) F-order double buffer, skipping the CSR -> +// CSC rebuild a `X[:, lb:ub].tocsc()` densify would do. // // `out` must be pre-zeroed; the atomicAdd also sums duplicate column indices -// (like scipy's sum_duplicates) -- bit-identical to a dense materialization for +// (like scipy's sum_duplicates) -- bit-identical to dense materialization for // canonical CSR. Output is always double; input dtype is templated. template diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu index 9893af17..db5f9ee8 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -9,9 +9,9 @@ namespace { constexpr int GROUP_STATS_BLOCK = 256; -// Benjamini-Hochberg step-up tail: in-place reverse cumulative minimum along -// each row (group) of an already BH-scaled, p-value-sorted matrix. NaNs are -// treated as 1.0. One block per row, single thread per row (serial scan). +// Benjamini-Hochberg step-up tail: in-place reverse cumulative minimum per row +// of an already BH-scaled, p-value-sorted matrix. NaNs treated as 1.0. One +// block per row, single thread (serial scan). __global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { const int row = blockIdx.x; double running = 1.0; @@ -28,11 +28,10 @@ __global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { } } -// Per-group sum / sum-of-squares / nnz over a dense F-order (column-major) -// block of shape (n_rows x n_cols). group_codes maps each row to a group; rows -// with an out-of-range code are skipped. Outputs are (n_groups x n_cols), -// C-order, accumulated with atomics. Grid-strided so a chunk larger than the -// gridDim.x cap is still fully covered. +// Per-group sum / sum-of-squares / nnz over a dense F-order block. group_codes +// maps each row to a group; out-of-range codes are skipped. C-order +// (n_groups x n_cols) outputs accumulated with atomics. Grid-strided so chunks +// larger than the gridDim.x cap are still fully covered. __global__ void group_chunk_stats_kernel( const double* block, const int* group_codes, double* group_sums, double* group_sum_sq, double* group_nnz, const int n_rows, const int n_cols, diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.h b/src/rapids_singlecell/_cuda/rmm_scratch.h index cc674e80..236c94dd 100644 --- a/src/rapids_singlecell/_cuda/rmm_scratch.h +++ b/src/rapids_singlecell/_cuda/rmm_scratch.h @@ -6,10 +6,8 @@ #include #include -// Shared RMM-backed device scratch, usable by any CUDA module that links -// rmm::rmm (see add_rmm_cuda_module in CMakeLists.txt). Allocations come from -// the current RMM device resource, so scratch participates in the same pool as -// CuPy/RAPIDS allocations. +// Shared RMM-backed device scratch (link rmm::rmm via add_rmm_cuda_module). +// Allocates from the current RMM resource, sharing CuPy/RAPIDS's pool. void* rmm_allocate(size_t bytes); void rmm_deallocate(void* ptr, size_t bytes); @@ -21,10 +19,7 @@ void rmm_deallocate(void* ptr, size_t bytes); // a smaller budget only adds passes. Use for every GPU-memory-budget decision. size_t rmm_available_device_bytes(double fraction); -// --------------------------------------------------------------------------- -// Small allocation pool for temporary CUDA buffers. Frees everything on scope -// exit; reuse a single pool across a kernel pipeline. -// --------------------------------------------------------------------------- +// Allocation pool for temporary CUDA buffers; frees everything on scope exit. struct RmmScratchPool { struct Allocation { void* ptr = nullptr; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 80e78de4..5e718d24 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -23,14 +23,8 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, return 0.0; } -/** - * OVR dense rank-sum kernel for data sorted by column. - * - * sorted_vals and sorted_row_idx are F-order arrays from a segmented - * SortPairs. One block owns one column, walks tie runs, and accumulates the - * average ranks per group without materializing a full rank matrix. - */ -// Dense OVR rank kernel. One block per column; walks sorted tie runs and +// Dense OVR rank kernel. sorted_vals/sorted_row_idx are F-order arrays from a +// segmented SortPairs. One block per column; walks sorted tie runs and // accumulates average ranks per group without materializing a rank matrix. // The `use_gmem` flag (set by ovr_smem_config) selects shared- vs // global-memory group accumulators -- CRITICAL: the use_gmem path is REQUIRED @@ -110,7 +104,7 @@ __global__ void rank_sums_from_sorted_kernel( for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; - if (grp < n_groups) { + if (grp >= 0 && grp < n_groups) { atomicAdd(&grp_sums[grp * acc_stride], avg_rank); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 2ae77947..6383bf5d 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -32,16 +32,11 @@ __device__ __forceinline__ double block_reduce_sum(double val, // ============================================================================ // Parallel tie correction — all threads collaborate. // -// For each unique value in the combined sorted (ref, grp) arrays, accumulate -// t^3 - t where t = count of that value. Uses two passes: -// 1. Iterate unique values in ref_col, count in both arrays. -// 2. Iterate unique values in grp_col that do NOT appear in ref_col. -// -// Incremental binary search bounds exploit monotonicity within each thread's -// stride to reduce total search work. -// -// Caller must __syncthreads() before calling. warp_buf is reused for -// reduction (32 doubles, shared memory). +// Accumulates t^3 - t per unique value of the combined (ref, grp) arrays via +// two passes: ref uniques (counted in both), then grp uniques absent from ref. +// Incremental binary search bounds exploit per-thread-stride monotonicity. +// Caller must __syncthreads() first. warp_buf (32 doubles) reused for +// reduction. // ============================================================================ __device__ __forceinline__ void compute_tie_correction_parallel( @@ -49,13 +44,12 @@ __device__ __forceinline__ void compute_tie_correction_parallel( double* warp_buf, double* out) { double local_tie = 0.0; - // Pass 1: unique values in ref_col + // Pass 1: unique values in ref_col, counted in both arrays int grp_lb = 0, grp_ub = 0; for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { if (i == 0 || ref_col[i] != ref_col[i - 1]) { float v = ref_col[i]; - // Count in ref: upper_bound from i+1 int lo = i + 1, hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -66,7 +60,6 @@ __device__ __forceinline__ void compute_tie_correction_parallel( } int cnt_ref = lo - i; - // Count in grp: incremental lower/upper bound lo = grp_lb; hi = n_grp; while (lo < hi) { @@ -105,7 +98,6 @@ __device__ __forceinline__ void compute_tie_correction_parallel( if (i == 0 || grp_col[i] != grp_col[i - 1]) { float v = grp_col[i]; - // Incremental lower_bound in ref int lo = ref_lb, hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -117,7 +109,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( ref_lb = lo; if (lo >= n_ref || ref_col[lo] != v) { - // Value not in ref — count in grp only (upper_bound from i+1) + // Value absent from ref — count in grp only lo = i + 1; hi = n_grp; while (lo < hi) { @@ -136,7 +128,6 @@ __device__ __forceinline__ void compute_tie_correction_parallel( } } - // Block-wide reduction double tie_sum = block_reduce_sum(local_tie, warp_buf); if (threadIdx.x == 0) { int n = n_ref + n_grp; @@ -147,12 +138,10 @@ __device__ __forceinline__ void compute_tie_correction_parallel( } // ============================================================================ -// Batched rank sums — pre-sorted (binary search, no shared memory sort) +// Batched rank sums — pre-sorted (binary search, no shared memory sort). // Used by the OVO streaming pipeline in wilcoxon_streaming.cu. -// -// Incremental binary search: each thread carries forward lower/upper bound -// positions across loop iterations, exploiting the monotonicity of the -// sorted grp_col values within each thread's stride. +// Each thread carries lower/upper bounds across iterations, exploiting +// sorted-grp_col monotonicity within its stride. // ============================================================================ __global__ void ovo_rank_huge_kernel( @@ -191,7 +180,6 @@ __global__ void ovo_rank_huge_kernel( float v = grp_col[i]; int lo, hi; - // Lower bound in ref (from ref_lb) lo = ref_lb; hi = n_ref; while (lo < hi) { @@ -204,7 +192,6 @@ __global__ void ovo_rank_huge_kernel( int n_lt_ref = lo; ref_lb = n_lt_ref; - // Upper bound in ref (from max(ref_ub, n_lt_ref)) lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; hi = n_ref; while (lo < hi) { @@ -217,7 +204,6 @@ __global__ void ovo_rank_huge_kernel( int n_eq_ref = lo - n_lt_ref; ref_ub = lo; - // Lower bound in grp (from grp_lb) lo = grp_lb; hi = n_grp; while (lo < hi) { @@ -230,7 +216,6 @@ __global__ void ovo_rank_huge_kernel( int n_lt_grp = lo; grp_lb = n_lt_grp; - // Upper bound in grp (from max(grp_ub, n_lt_grp)) lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; hi = n_grp; while (lo < hi) { @@ -300,7 +285,7 @@ __global__ void ovo_rank_large_kernel( float* grp_smem = (float*)smem_raw; double* warp_buf = (double*)(smem_raw + large_padded * sizeof(float)); - // Load group data into shared memory, pad with +INF + // Load group into smem, pad with +INF const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; for (int i = threadIdx.x; i < n_grp; i += blockDim.x) grp_smem[i] = grp_col[i]; @@ -326,8 +311,8 @@ __global__ void ovo_rank_large_kernel( } } - // Binary search each sorted grp element against sorted ref - // Incremental bounds: values are monotonic within each thread's stride + // Binary search each sorted grp element against sorted ref; + // incremental bounds (monotonic within each thread's stride) const float* ref_col = ref_sorted + (long long)col * n_ref; int ref_lb = 0, ref_ub = 0; int grp_lb = 0, grp_ub = 0; @@ -389,14 +374,13 @@ __global__ void ovo_rank_large_kernel( ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; } - // Block reduction → write rank_sums double total = block_reduce_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; __syncthreads(); - // Parallel tie correction (grp_smem is sorted shared memory) + // grp_smem is sorted here compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, &tie_corr[grp * n_cols + col]); } @@ -681,8 +665,8 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, int lane = threadIdx.x & 31; double local_tie = 0.0; - // Pass 1: for each unique value in ref_col, count occurrences in ref and - // in the sorted group (held in register v_lane across 32 lanes). + // Pass 1: per unique ref value, count occurrences in ref and in the + // sorted group (held in register v_lane across 32 lanes). for (int base = 0; base < n_ref; base += 32) { int i = base + lane; bool in_ref_lane = (i < n_ref); @@ -690,7 +674,6 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); int cnt_ref = 0; if (is_first) { - // Count in ref: upper_bound from i+1 int lo = i + 1, hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -734,7 +717,6 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, bool first_in_grp = (lane == 0) || (v != v_prev); bool in_ref = false; if (first_in_grp) { - // Binary search in ref. int lo = 0, hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -765,7 +747,6 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, } } - // Warp reduce. #pragma unroll for (int off = 16; off > 0; off >>= 1) local_tie += __shfl_down_sync(0xffffffff, local_tie, off); @@ -835,17 +816,11 @@ __device__ __forceinline__ double warp_tie_delta(const float* ref_col, // ============================================================================ // WARP-band kernel: warp-per-(col, group) pair, 8 warps packed per block. // -// Each warp independently: -// 1. Loads ≤ 32 group values into a single register (one per lane, -// padded with +INF). -// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. -// 3. Binary-searches into sorted ref for each lane's value and -// accumulates the rank-sum term. -// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. -// -// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair -// LARGE band, and the lack of __syncthreads / smem sort lets each warp run -// independently at full throughput. +// Each warp independently loads ≤32 group values into registers (one per +// lane), bitonic-sorts via __shfl_xor_sync, binary-searches into sorted ref, +// and warp-reduces to lane 0. 8 pairs/block cuts block count 8× vs the +// block-per-pair LARGE band; no smem/__syncthreads lets warps run at full +// throughput independently. // // Grid: (n_cols, ceil(n_groups / 8)), Block: 256. // ============================================================================ @@ -907,7 +882,6 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, if (lane < n_grp) { float v = x; - // Lower bound in ref. int lo = 0, hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -917,7 +891,6 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, hi = m; } int n_lt_ref = lo; - // Upper bound in ref. hi = n_ref; while (lo < hi) { int m = lo + ((hi - lo) >> 1); @@ -947,9 +920,8 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, } } int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; - // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + - // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane - // gets the same mid-rank. This matches the LARGE-band accumulation. + // Per-lane mid-rank; each tie lane gets the same value (matches LARGE + // band). local_sum = (double)(n_lt_ref + n_lt_grp) + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; } @@ -962,7 +934,6 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, if (!compute_tie_corr) return; - // Warp-scoped tie correction. double tie_sum; if (ref_tie_sums != nullptr) { tie_sum = ref_tie_sums[col] + diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 03a0f0da..c1c30268 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -93,9 +93,10 @@ static void launch_ovr_rank_dense_streaming( "dense OVR segmented sort"); if (use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "dense OVR gmem rank_sums memset"); } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, @@ -103,14 +104,17 @@ static void launch_ovr_rank_dense_streaming( use_gmem); CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense OVR rank_sums D2D copy"); if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); + cuda_check(cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream), + "dense OVR tie_corr D2D copy"); } col += sb_cols; @@ -136,9 +140,11 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; std::vector h_offsets(n_groups + 1); - cudaStreamSynchronize(upstream_stream); - cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); + cuda_check(cudaStreamSynchronize(upstream_stream), + "dense OVO sync before offsets D2H"); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "dense OVO group offsets D2H"); auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; bool run_large = t1.above_medium && t1.run_large; @@ -188,9 +194,10 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( int* d_sort_group_ids = nullptr; if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); - cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), - h_sort_group_ids.size() * sizeof(int), - cudaMemcpyHostToDevice); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "dense OVO sort group ids H2D"); } struct StreamBuf { @@ -270,15 +277,19 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, stream); - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), - buf.sub_tie_corr, sb_cols * sizeof(double), + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyDeviceToDevice, stream), + "dense OVO rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check( + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense OVO tie_corr D2D copy"); } col += sb_cols; @@ -308,6 +319,24 @@ void register_bindings(nb::module_& m) { gpu_array_c tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { + nb_require(ref_data.ndim() == 2 && grp_data.ndim() == 2 && + rank_sums.ndim() == 2 && tie_corr.ndim() == 2 && + grp_offsets.ndim() == 1, + "ovo_rank: data/outputs must be 2D, grp_offsets 1D"); + nb_require((int)ref_data.shape(0) == n_ref && + (int)ref_data.shape(1) == n_cols, + "ovo_rank: ref_data shape must be (n_ref, n_cols)"); + nb_require((int)grp_data.shape(0) == n_all_grp && + (int)grp_data.shape(1) == n_cols, + "ovo_rank: grp_data shape must be (n_all_grp, n_cols)"); + nb_require((int)grp_offsets.shape(0) >= n_groups + 1, + "ovo_rank: grp_offsets length must be >= n_groups + 1"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovo_rank: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_groups && + (int)tie_corr.shape(1) == n_cols, + "ovo_rank: tie_corr shape must be (n_groups, n_cols)"); launch_ovo_rank_dense_tiered_unsorted_ref( ref_data.data(), grp_data.data(), grp_offsets.data(), rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, @@ -327,6 +356,19 @@ void register_bindings(nb::module_& m) { gpu_array_c tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { + nb_require(block.ndim() == 2 && rank_sums.ndim() == 2 && + group_codes.ndim() == 1 && tie_corr.ndim() == 1, + "ovr_rank: block/rank_sums 2D, group_codes/tie_corr 1D"); + nb_require( + (int)block.shape(0) == n_rows && (int)block.shape(1) == n_cols, + "ovr_rank: block shape must be (n_rows, n_cols)"); + nb_require((int)group_codes.shape(0) == n_rows, + "ovr_rank: group_codes length must be n_rows"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovr_rank: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_cols, + "ovr_rank: tie_corr length must be n_cols"); launch_ovr_rank_dense_streaming( block.data(), group_codes.data(), rank_sums.data(), tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 25aa2402..da1aa18c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -61,34 +61,27 @@ constexpr int N_STREAMS = 4; constexpr int SUB_BATCH_COLS = 64; constexpr int BEGIN_BIT = 0; constexpr int END_BIT = 32; -// Default thread-per-block for utility kernels (extract, gather, offsets, -// etc.). +// Default thread-per-block for utility kernels. constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; -// Max group size for the super-fast "warp-per-(col,group)" fused kernel (the -// WARP band). Each warp sorts and ranks one (col, group) pair entirely in -// registers via warp-shuffle bitonic sort — no smem sort buffer, no -// __syncthreads(). Blocks pack 8 warps so block launch overhead is -// amortised 8× across (col, group) work items. This path is the fast -// route for per-celltype perturbation-style workloads where most test -// groups have only a few dozen cells. +// WARP band: warp-per-(col,group) fused kernel. Each warp sorts+ranks one +// pair entirely in registers (warp-shuffle bitonic, no smem, no __syncthreads). +// Blocks pack 8 warps to amortise launch overhead. Fast route for +// perturbation-style workloads where most groups have a few dozen cells. constexpr int OVO_WARP_MAX = 32; -// SMALL band for perturbation workloads where most groups are slightly larger -// than one warp. Uses one compact shared-memory sort block per (column, -// group), avoiding the heavier MEDIUM-band in-group scan. +// SMALL band: groups slightly larger than one warp. One compact smem sort +// block per (col, group), avoiding the heavier MEDIUM-band in-group scan. constexpr int OVO_SMALL_MAX = 64; -// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation -// workloads most groups sit below this range, where avoiding a full smem -// bitonic sort wins despite the O(n^2) in-group count. +// MEDIUM band: unsorted direct-rank kernel. Avoiding a full smem bitonic sort +// wins here despite the O(n^2) in-group count. constexpr int OVO_MEDIUM_MAX = 512; // Max group size for the fused smem-sort rank kernel (the LARGE band). // Beyond this, fall back to the HUGE band: CUB segmented sort + rank kernel. constexpr int OVO_LARGE_MAX = 2500; -// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes -// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = -// fewer kernel launches; smaller = less per-stream memory. 128M items × 4B = -// 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. +// Per-stream dense slab budget (float32 items). Sub-batching keeps +// (n_g × eff_sb_cols) ≤ this. 128M × 4B = 512 MB slab + same for sorted copy +// ≈ 1 GB / stream. Bigger = fewer launches; smaller = less per-stream memory. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; // Query CUB device-segmented-radix-sort scratch size with a dummy launch. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index e877ea11..71c4b64e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -27,8 +27,9 @@ static void ovo_streaming_csr_impl( } std::vector h_offsets(n_groups + 1); - cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "device OVO group offsets D2H"); auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; bool run_large = t1.above_medium && t1.run_large; @@ -88,9 +89,10 @@ static void ovo_streaming_csr_impl( int* d_sort_group_ids = nullptr; if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); - cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), - h_sort_group_ids.size() * sizeof(int), - cudaMemcpyHostToDevice); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "device OVO sort group ids H2D"); } struct StreamBuf { @@ -256,8 +258,9 @@ static void ovo_streaming_csc_impl( } std::vector h_offsets(n_groups + 1); - cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "device OVO group offsets D2H"); auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); int max_grp_size = t1.max_grp_size; bool run_large = t1.above_medium && t1.run_large; @@ -300,9 +303,10 @@ static void ovo_streaming_csc_impl( int* d_sort_group_ids = nullptr; if (run_huge) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); - cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), - h_sort_group_ids.size() * sizeof(int), - cudaMemcpyHostToDevice); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "device OVO sort group ids H2D"); } struct StreamBuf { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index d5cfc088..218d075c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -31,7 +31,6 @@ static void ovo_streaming_csc_host_impl( [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); } - // ---- Tier dispatch from host offsets ---- auto t1 = make_ovo_tier_plan(h_grp_offsets, n_groups); int max_grp_size = t1.max_grp_size; bool run_large = t1.above_medium && t1.run_large; @@ -95,6 +94,13 @@ static void ovo_streaming_csc_host_impl( // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; + // Pin host inputs before the streams so on an exception unwind the streams + // drain before the buffers are unregistered (mirrors the safe CSR order). + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); ScopedCudaStreams streams(n_streams, cudaStreamDefault); int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; @@ -115,7 +121,7 @@ static void ovo_streaming_csc_host_impl( cudaMemcpy(d_all_offsets, h_all_offsets.data(), h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - // GPU copies of row maps + group offsets + stats codes (uploaded once) + // Row maps + group offsets + stats codes (uploaded once) int* d_ref_row_map = pool.alloc(n_rows); int* d_grp_row_map = pool.alloc(n_rows); int* d_grp_offsets = pool.alloc(n_groups + 1); @@ -198,13 +204,6 @@ static void ovo_streaming_csc_host_impl( size_t smem_cast = cast_accumulate_smem_config(n_groups_stats, compute_nnz, cast_use_gmem); - // Pin only the sparse input arrays; outputs live on the device. - size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(IndexT)); - int col = 0; int batch_idx = 0; while (col < n_cols) { @@ -219,7 +218,7 @@ static void ovo_streaming_csc_host_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // ---- H2D: sparse data for this column range (native dtype) ---- + // H2D: sparse data for this column range (native dtype) IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); @@ -232,14 +231,14 @@ static void ovo_streaming_csc_host_impl( cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); - // ---- Cast to float32 for sort + accumulate stats in float64 ---- + // Cast to float32 for sort + accumulate stats in float64 launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_nnz, sb_cols, n_groups_stats, compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); - // ---- Extract ref from CSC via row_map, sort ---- + // Extract ref from CSC via row_map, sort cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), stream); csc_extract_mapped_kernel<<>>( @@ -256,7 +255,7 @@ static void ovo_streaming_csc_host_impl( "host CSC OVO ref segmented sort"); } - // ---- Extract grp from CSC via row_map ---- + // Extract grp from CSC via row_map cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), stream); csc_extract_mapped_kernel<<>>( @@ -264,7 +263,7 @@ static void ovo_streaming_csc_host_impl( d_grp_row_map, buf.grp_dense, n_all_grp, 0); CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - // ---- Tier dispatch: sort grp + rank ---- + // Tier dispatch: sort grp + rank OvoTierScratch sc{buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, buf.grp_sorted, buf.grp_seg_offsets, buf.grp_seg_ends, @@ -274,7 +273,7 @@ static void ovo_streaming_csc_host_impl( sb_grp_actual, tpb_rank, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, stream); - // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + // D2D: scatter sub-batch results into caller's GPU buffers cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, @@ -341,7 +340,7 @@ static void ovo_streaming_csr_host_impl( bool compute_nnz, bool compute_sums, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; - // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)). // Use IndptrT for the global compacted indptr because the grp side can // exceed 2^31 nnz on very large / dense matrices. Ref always fits in // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the @@ -374,7 +373,7 @@ static void ovo_streaming_csr_host_impl( h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; } - // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + // Build packs (same rule as grp_impl, but uses compacted indptr) struct Pack { int first; int end; @@ -452,7 +451,6 @@ static void ovo_streaming_csr_host_impl( RmmScratchPool pool; - // Zero stats outputs. if (compute_sums) { cudaMemsetAsync(d_group_sums, 0, (size_t)n_groups_stats * n_cols * sizeof(double)); @@ -462,7 +460,7 @@ static void ovo_streaming_csr_host_impl( (size_t)n_groups_stats * n_cols * sizeof(double)); } - // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + // Pin full host data + indices as MAPPED (zero-copy accessible) size_t full_nnz = (size_t)h_indptr[n_full_rows]; HostRegisterGuard _pin_data(const_cast(h_data), full_nnz * sizeof(InT), cudaHostRegisterMapped); @@ -486,12 +484,12 @@ static void ovo_streaming_csr_host_impl( } } - // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + // Upload full indptr (keep native IndptrT — can exceed int32) IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); - // ---- Upload row_ids + compacted indptrs + group boundaries ---- + // Upload row_ids + compacted indptrs + group boundaries int* d_ref_row_ids = pool.alloc(n_ref); int* d_grp_row_ids = pool.alloc(n_all_grp); IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); @@ -505,7 +503,7 @@ static void ovo_streaming_csr_host_impl( cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), cudaMemcpyHostToDevice); - // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + // Phase 1: Ref setup (scoped scratch, ref_sorted persists). // The full-width sorted reference cache d_ref_sorted is [n_ref × n_cols], // but it is built one COLUMN CHUNK at a time so each CUB segmented sort // stays within int32 (n_ref × ref_chunk_cols items) and the dense extract @@ -546,13 +544,11 @@ static void ovo_streaming_csr_host_impl( float* d_ref_dense = (float*)ref_dense_buf.data(); int* d_ref_seg = (int*)ref_seg_buf.data(); - // Upload ref compacted indptr cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); - // Fused gather + cast + stats for ref (fixed slot = n_test). One - // pass over PCIe, no intermediate native-dtype GPU buffer. Stats for - // all columns are accumulated here, once. + // Fused gather + cast + stats for ref (fixed slot = n_test): one PCIe + // pass, no intermediate native-dtype buffer, all-column stats once. if (n_ref > 0 && ref_nnz > 0) { csr_gather_cast_accumulate_mapped_kernel <<>>( @@ -568,8 +564,7 @@ static void ovo_streaming_csr_host_impl( ref_chunk_items_i32, ref_chunk_cols); ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); - // Extract + segment-sort the reference one column chunk at a time, - // writing each chunk into its slice of the full-width sorted cache. + // Extract + segment-sort the reference one column chunk at a time. for (int cs = 0; cs < n_cols; cs += ref_chunk_cols) { int ce = std::min(cs + ref_chunk_cols, n_cols); int cc = ce - cs; @@ -595,7 +590,7 @@ static void ovo_streaming_csr_host_impl( "host CSR OVO ref sort sync"); } // ref scratch drops here - // ---- Phase 2: Per-pack streaming ---- + // Phase 2: Per-pack streaming auto t1 = make_ovo_tier_plan(h_grp_offsets, n_test); bool may_need_cub = (t1.max_grp_size > OVO_LARGE_MAX); @@ -709,8 +704,8 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Build per-pack group offsets on GPU (on this stream) — needed to - // compute stats codes before the fused gather kernel can run. + // Build per-pack group offsets on GPU — needed for stats codes before + // the fused gather kernel can run. { int count = K + 1; int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -719,7 +714,6 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Fill per-row stats codes for this pack { int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; fill_pack_stats_codes_kernel<<>>( @@ -727,9 +721,8 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); } - // Fused gather + cast + stats for the pack. One pass over PCIe - // (reads mapped host via UVA), no intermediate native-dtype GPU - // buffer, writes f32 + indices + atomics. + // Fused gather + cast + stats for the pack: one PCIe pass (reads mapped + // host via UVA), no intermediate native-dtype buffer. if (pack.nnz > 0) { csr_gather_cast_accumulate_mapped_kernel <<>>( @@ -742,7 +735,6 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); } - // Per col sub-batch int col = 0; while (col < n_cols) { int sb_cols = std::min(pack_sb, n_cols - col); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index f162782a..84ed7fab 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -3,10 +3,9 @@ #include /** - * Build CUB segmented-sort ranges only for groups in the HUGE band. - * Group ids are relative to grp_offsets, and ranges still point into the - * original dense group layout so the presorted rank kernel can read from the - * normal per-group positions. + * Build CUB segmented-sort ranges for HUGE-band groups. Ranges point into the + * original dense group layout so the presorted rank kernel reads normal + * per-group positions. */ __global__ void build_huge_seg_offsets_kernel( const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, @@ -25,10 +24,9 @@ __global__ void build_huge_seg_offsets_kernel( } /** - * Extract specific rows from CSC into dense F-order, using a row lookup map. + * Extract rows from CSC into dense F-order via a row lookup map. * row_map[original_row] = output_row_index (or -1 to skip). - * One block per column, threads scatter matching nonzeros. - * Output must be pre-zeroed. + * One block per column. Output must be pre-zeroed. */ template __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, @@ -52,12 +50,10 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } /** - * LARGE-band dispatch: when the largest group fits in shared memory, a fused - * bitonic-sort + binary-search kernel handles the whole group per block. - * Otherwise we fall back to the HUGE band (CUB segmented sort plus the - * pre-sorted rank kernel). This struct bundles the sizing knobs derived from - * the host-side group offsets so each streaming impl can drop a 15-line prep - * block. + * Sizing knobs for LARGE-band dispatch: when the largest group fits in shared + * memory, a fused bitonic-sort + binary-search kernel handles the group per + * block; otherwise fall back to the HUGE band (CUB segmented sort + pre-sorted + * rank kernel). */ struct OvoTierPlan { int max_grp_size = 0; @@ -102,10 +98,7 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { } if (n_groups == 0) c.min_grp_size = 0; - // run_warp: WARP kernel is worth running (at least one group small - // enough to benefit from the warp path). c.run_warp = (c.min_grp_size <= OVO_WARP_MAX); - // above_warp: at least one group needs a non-WARP kernel. c.above_warp = (c.max_grp_size > OVO_WARP_MAX); // run_large: the fused smem-sort fast path (groups > WARP but ≤ LARGE). c.run_large = c.above_warp && (c.max_grp_size <= OVO_LARGE_MAX); @@ -115,11 +108,10 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { c.large_tpb = std::min(c.large_padded, MAX_THREADS_PER_BLOCK); c.large_smem = (size_t)c.large_padded * sizeof(float) + WARP_REDUCE_BUF * sizeof(double); - // Adapt to the device: if the fused-sort buffer would exceed the - // per-block shared-memory limit, fall back to the HUGE-band CUB - // segmented sort (no smem cap) rather than launching a kernel that - // would fail. Never triggers at the current threshold (~16.6KB), but - // keeps the dispatch correct if the threshold or device limit changes. + // Device-adapt: if the fused-sort buffer exceeds the per-block smem + // limit, fall back to HUGE (no smem cap) instead of launching a kernel + // that would fail. Inert at the current ~16.6KB threshold; guards + // against threshold/device-limit changes. if (c.large_smem > wilcoxon_max_smem_per_block()) { c.run_large = false; } @@ -206,13 +198,11 @@ struct OvoTierScratch { }; // SINGLE OVO ranking engine, shared by the dense path and all four sparse OVO -// impls (host/device CSC/CSR). Given an already-sorted reference slice and a -// dense group slice for one column sub-batch, it runs the size-banded dispatch -// from `plan` (see make_ovo_tier_plan): co-launch WARP/SMALL/MEDIUM for small -// groups, then LARGE (fused smem sort) OR HUGE (CUB segmented sort) for the -// rest. Pure host-side code motion: the kernel launches are identical to the -// previous inline copies, so results and performance are unchanged. The five -// callers differ only in how they produce ref_sorted / grp_dense. +// impls (host/device CSC/CSR). Given a sorted reference slice and a dense group +// slice for one column sub-batch, runs the size-banded dispatch from `plan` +// (see make_ovo_tier_plan): co-launch WARP/SMALL/MEDIUM for small groups, then +// LARGE (fused smem sort) OR HUGE (CUB segmented sort) for the rest. Callers +// differ only in how they produce ref_sorted / grp_dense. static inline void ovo_dispatch_tiers( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const OvoTierPlan& plan, const OvoTierScratch& sc, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index bb12ac3e..882d2e0e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -18,20 +18,17 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, /** * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). - * write_pos[c - col_start] must be initialized to the prefix-sum offset - * for column c. Each thread atomically claims a unique destination slot. + * write_pos[c - col_start] is the prefix-sum offset for column c; each thread + * atomically claims a unique destination slot. * - * PRECONDITION: each row's `indices` must be sorted ascending. The binary - * search for col_start and the `break` at col_stop both depend on it; unsorted - * rows would silently drop or misplace nonzeros. Every caller enforces this -- - * the Python dispatch calls `sort_indices()` on the CSR/CSC input before - * invoking the streaming impls that launch this kernel. + * PRECONDITION: each row's `indices` must be sorted ascending -- the binary + * search for col_start and the `break` at col_stop depend on it; unsorted rows + * would silently drop or misplace nonzeros. Python dispatch calls + * `sort_indices()` before launching this kernel. * - * `row_offset` is added to the local row index when writing csc_row_idx, so a - * row-block whose indptr/data are rebased to a local [0, n_rows) range still - * records the correct global row id (used by the out-of-core row-streaming OVR - * path that feeds bulk-transferred row-blocks). Defaults to 0 for callers that - * pass the full matrix. + * `row_offset` is added to the local row index so a row-block rebased to a + * local [0, n_rows) range still records the correct global row id (out-of-core + * row-streaming OVR path). Defaults to 0 for full-matrix callers. */ template __global__ void csr_scatter_to_csc_kernel( @@ -63,14 +60,12 @@ __global__ void csr_scatter_to_csc_kernel( // CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). // -// Decide smem-vs-gmem for the DENSE OVR rank kernel -// (rank_sums_from_sorted_kernel). Per-block accumulator is one double per group -// plus a 32-slot warp buffer, i.e. (n_groups + 32) doubles. When that exceeds -// the per-block smem limit (~48 KB) the kernel must fall back to a -// global-memory accumulator (use_gmem=true). With a 48 KB limit this flips at -// roughly n_groups > 6112. Not dead: a kernel launched in smem mode with an -// oversized request simply fails to launch. Limit is device-queried via -// wilcoxon_max_smem_per_block(), so it auto-scales. +// Decide smem-vs-gmem for the DENSE OVR rank kernel. Per-block accumulator is +// (n_groups + 32) doubles; when that exceeds the per-block smem limit (~48 KB) +// it must fall back to a global-memory accumulator (use_gmem=true), flipping at +// roughly n_groups > 6112. Not dead: smem mode with an oversized request fails +// to launch. Limit is device-queried via wilcoxon_max_smem_per_block(), so it +// auto-scales. static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { @@ -86,16 +81,14 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { * CRITICAL — DO NOT REMOVE the gmem branch. This is the load-bearing path for * Perturb-seq / pooled-CRISPR DE, where n_groups is in the thousands. * - * Decide smem-vs-gmem for the sparse OVR rank kernel. The per-block accumulator - * is two double arrays of size n_groups (grp_sums + grp_nz_count) plus a - * 32-slot warp buffer, i.e. (2*n_groups + 32) doubles. When that exceeds the - * per-block shared-memory limit (~48 KB) the kernel CANNOT launch in smem mode, - * so we set use_gmem=true and rank_sums_sparse_ovr_kernel accumulates in a - * caller-provided global-memory buffer instead. With a 48 KB limit this flips - * at roughly n_groups > 3056. Reviewers/static analysis have twice mistaken - * this fallback for dead code; it is the ONLY path that works at large - * n_groups. The limit is queried per device via wilcoxon_max_smem_per_block(), - * so the threshold auto-scales with the GPU. + * Decide smem-vs-gmem for the sparse OVR rank kernel. Per-block accumulator is + * (2*n_groups + 32) doubles (grp_sums + grp_nz_count + warp buf); when that + * exceeds the per-block smem limit (~48 KB) the kernel CANNOT launch in smem + * mode, so use_gmem=true routes accumulation to a caller-provided gmem buffer. + * Flips at roughly n_groups > 3056. Reviewers/static analysis have twice + * mistaken this fallback for dead code; it is the ONLY path that works at large + * n_groups. Limit is device-queried via wilcoxon_max_smem_per_block(), so the + * threshold auto-scales. */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index dc6b2415..6ee17a0a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -35,7 +35,6 @@ static void ovr_sparse_csc_host_streaming_impl( if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - // Find max nnz across any sub-batch size_t max_nnz = 0; for (int col = 0; col < n_cols; col += sub_batch_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); @@ -43,7 +42,6 @@ static void ovr_sparse_csc_host_streaming_impl( if (nnz > max_nnz) max_nnz = nnz; } - // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { int max_nnz_i32 = @@ -54,6 +52,13 @@ static void ovr_sparse_csc_host_streaming_impl( // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; + // Pin host inputs before the streams so on an exception unwind the streams + // drain before the buffers are unregistered (mirrors the safe CSR order). + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); ScopedCudaStreams streams(n_streams, cudaStreamDefault); int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); @@ -90,7 +95,6 @@ static void ovr_sparse_csc_host_streaming_impl( : nullptr; } - // Transfer group codes + sizes once cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), @@ -134,13 +138,6 @@ static void ovr_sparse_csc_host_streaming_impl( } } - // Pin only the host input arrays; outputs live on the device. - size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(IndexT)); - cudaDeviceSynchronize(); int col = 0; @@ -156,7 +153,7 @@ static void ovr_sparse_csc_host_streaming_impl( int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), "OVR host CSC active batch nnz"); - // H2D: transfer sparse data for this column range (native dtype) + // H2D: this column range's sparse data (native dtype) if (batch_nnz > 0) { cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, (size_t)batch_nnz * sizeof(InT), @@ -166,7 +163,6 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyHostToDevice, stream); } - // D2D: copy this batch's rebased offsets from the pre-uploaded buffer int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); @@ -178,7 +174,7 @@ static void ovr_sparse_csc_host_streaming_impl( sb_cols, n_groups, compute_nnz, tpb, smem_cast, cast_use_gmem, stream); - // CUB sort only stored nonzeros (float32 keys) + // Sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { size_t temp = cub_temp_bytes; cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( @@ -189,14 +185,12 @@ static void ovr_sparse_csc_host_streaming_impl( "host CSC OVR segmented sort"); } - // Sparse rank kernel (stats already captured above) launch_ovr_sparse_rank( buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - // D2D: scatter sub-batch results into caller's GPU buffers cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, @@ -263,8 +257,7 @@ static void ovr_sparse_csr_host_rowstream_impl( size_t budget = rmm_available_device_bytes(0.8); // ---- Phase 0: column histogram on the host, threaded by row range. Each - // worker counts its rows' columns into a private array (cache-resident, no - // false sharing), merged afterwards. No device transfer. ---- + // worker counts into a private array (no false sharing), merged after. ---- std::vector h_col_counts(n_cols, 0); { int n_workers = host_worker_count(); @@ -321,8 +314,8 @@ static void ovr_sparse_csr_host_rowstream_impl( size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); - // ---- Host gather staging (pinned for fast bulk H2D) + per-row cursor. The - // full CSR is NOT page-locked: the gather reads it on the CPU; only the + // ---- Host gather staging (pinned for fast bulk H2D) + per-row cursor. + // Full CSR is NOT page-locked: gather reads it on the CPU, only the // compacted per-batch slice crosses the bus. ---- size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; std::vector h_gather_vals(stage_nnz); @@ -362,12 +355,11 @@ static void ovr_sparse_csr_host_rowstream_impl( rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; - // ---- One linear pass over the matrix, column-batched. The per-row cursor - // advances monotonically (indices sorted + batches ascending), so every - // nonzero is read on the host and transferred exactly once across all - // batches -- no whole-matrix re-streaming. The host gather is threaded: - // count each row's run (binary search), prefix-sum to per-row output - // offsets, then copy rows in parallel into disjoint staging ranges. ---- + // ---- One linear pass, column-batched. The per-row cursor advances + // monotonically (indices sorted + batches ascending), so each nonzero is + // read and transferred exactly once -- no whole-matrix re-streaming. Gather + // is threaded: count each row's run, prefix-sum to per-row output offsets, + // copy rows in parallel into disjoint staging ranges. ---- std::vector g_count(n_rows); int col = 0; for (int b = 0; b < n_batches; b++) { @@ -384,8 +376,7 @@ static void ovr_sparse_csr_host_rowstream_impl( (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); } }); - // Prefix sum -> per-row output offsets (the gathered mini-CSR's row - // pointer). + // Prefix sum -> per-row output offsets (gathered mini-CSR row pointer). h_gather_indptr[0] = 0; for (int r = 0; r < n_rows; r++) h_gather_indptr[r + 1] = checked_int_span( @@ -413,8 +404,7 @@ static void ovr_sparse_csr_host_rowstream_impl( cudaMemcpy(write_pos, off, sb_cols * sizeof(int), cudaMemcpyHostToDevice); - // Bulk H2D of just this batch's compacted nonzeros (each transferred - // once over the matrix's lifetime) + the per-row split. + // Bulk H2D of this batch's compacted nonzeros (1x transfer). if (batch_nnz > 0) { cuda_check(cudaMemcpy(d_gather_vals, h_gather_vals.data(), (size_t)batch_nnz * sizeof(InT), @@ -428,7 +418,7 @@ static void ovr_sparse_csr_host_rowstream_impl( cudaMemcpy(d_gather_indptr, h_gather_indptr.data(), (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice); - // Scatter the gathered mini-CSR into the column-batch CSC accumulator. + // Scatter mini-CSR into the column-batch CSC accumulator. csr_scatter_to_csc_kernel <<<(n_rows + tpb - 1) / tpb, tpb>>>( d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, @@ -819,7 +809,6 @@ static void ovr_sparse_csc_streaming_impl( if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - // Find max nnz across any sub-batch for buffer sizing size_t max_nnz = 0; for (int col = 0; col < n_cols; col += sub_batch_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); @@ -827,7 +816,6 @@ static void ovr_sparse_csc_streaming_impl( if (nnz > max_nnz) max_nnz = nnz; } - // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { int max_nnz_i32 = @@ -911,7 +899,6 @@ static void ovr_sparse_csc_streaming_impl( sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, @@ -1007,7 +994,7 @@ static void ovr_sparse_csr_streaming_impl( if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } - // Upload all batch offsets to GPU in one shot (~20 KB) + // Upload all batch offsets in one H2D (~20 KB) int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); cudaMemcpy(d_all_offsets, h_all_offsets.data(), @@ -1091,23 +1078,22 @@ static void ovr_sparse_csr_streaming_impl( int batch_nnz = checked_int_span(h_batch_nnz[b], "OVR device CSR active batch nnz"); - // D2D copy pre-computed col_offsets for this batch int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); - // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + // write_pos = col_offsets[0..sb_cols-1] (same D2D source) cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), cudaMemcpyDeviceToDevice, stream); if (batch_nnz > 0) { - // Scatter CSR → CSC layout for this sub-batch + // Scatter CSR → CSC for this sub-batch csr_scatter_to_csc_kernel<<>>( csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, buf.csc_row_idx, n_rows, col, col + sb_cols); CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); - // CUB sort only the nonzeros + // Sort only the nonzeros size_t temp = cub_temp_bytes; cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( buf.cub_temp, temp, buf.csc_vals, buf.keys_out, @@ -1124,7 +1110,6 @@ static void ovr_sparse_csr_streaming_impl( sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index a28d1c6c..90e66793 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -67,7 +67,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int acc_stride; if (use_gmem) { - // Output rank_sums doubles as accumulator (pre-zeroed by caller). + // rank_sums doubles as accumulator (pre-zeroed by caller). grp_sums = rank_sums + (size_t)col; grp_nz_count = nz_count_scratch + (size_t)col; acc_stride = sb_cols; @@ -82,10 +82,9 @@ __global__ void rank_sums_sparse_ovr_kernel( __syncthreads(); } - // --- Find stored zero range: pos_start = first val > 0 --- + // pos_start = first index where sv[i] > 0 (stored zeros precede positives). __shared__ int sh_pos_start; if (threadIdx.x == 0) { - // Binary search: first index where sv[i] > 0.0 int lo = 0, hi = nnz_stored; while (lo < hi) { int mid = lo + ((hi - lo) >> 1); @@ -110,16 +109,16 @@ __global__ void rank_sums_sparse_ovr_kernel( // = n_implicit_zero + (a + b + 1) / 2 int offset_pos = n_implicit_zero; - // --- Count stored positive values per group --- + // Count stored positives per group. for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { int grp = group_codes[si[i]]; - if (grp < n_groups) { + if (grp >= 0 && grp < n_groups) { atomicAdd(&grp_nz_count[(size_t)grp * acc_stride], 1.0); } } __syncthreads(); - // --- Zero-rank contribution per group --- + // Analytic zero contribution: each group's zeros all get zero_avg_rank. for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { double n_zero_in_g = group_sizes[g] - grp_nz_count[(size_t)g * acc_stride]; @@ -127,7 +126,7 @@ __global__ void rank_sums_sparse_ovr_kernel( } __syncthreads(); - // --- Walk stored positives only and compute ranks --- + // Walk stored positives and compute tie-averaged ranks. int n_pos = nnz_stored - pos_start; int chunk = (n_pos + blockDim.x - 1) / blockDim.x; int my_start = pos_start + threadIdx.x * chunk; @@ -146,7 +145,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int tie_global_start = i; if (i == my_start && i > 0 && sv[i - 1] == val) { - // Binary search for first occurrence + // tie spans into a prior chunk: find global tie start. int lo = pos_start, hi = i; while (lo < hi) { int mid = lo + ((hi - lo) >> 1); @@ -179,7 +178,7 @@ __global__ void rank_sums_sparse_ovr_kernel( for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; - if (grp < n_groups) { + if (grp >= 0 && grp < n_groups) { atomicAdd(&grp_sums[(size_t)grp * acc_stride], avg_rank); } } @@ -203,7 +202,7 @@ __global__ void rank_sums_sparse_ovr_kernel( // Tie correction: warp + block reduction if (compute_tie_corr) { - // Zero tie group contribution (one thread only) + // Single zero tie block contributes once. if (threadIdx.x == 0 && total_zero > 1) { double tz = (double)total_zero; local_tie_sum += tz * tz * tz - tz; @@ -328,7 +327,7 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( data_f32_out[i] = (float)v_in; int row = (int)indices[i]; int g = group_codes[row]; - if (g < n_groups) { + if (g >= 0 && g < n_groups) { atomicAdd(&s_sum[g], v); if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); } @@ -367,7 +366,7 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( data_f32_out[i] = (float)v_in; int row = (int)indices[i]; int g = group_codes[row]; - if (g < n_groups) { + if (g >= 0 && g < n_groups) { atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); if (compute_nnz && v != 0.0) { atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0ca84bfa..3ddaf067 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -77,15 +77,30 @@ def rank_genes_groups( Rank genes for characterizing groups using GPU acceleration. Log1p/log-normalized data is expected for biologically meaningful log fold - changes. Complex values are rejected. Sparse inputs with explicit negative - values fall back to the dense full-sort ranking path; dense inputs are - ranked directly and support any sign. + changes. Sparse inputs with explicit negative values fall back to the dense + full-sort ranking path; dense inputs are ranked directly and support any + sign. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, `'wilcoxon_binned'`, and `'logreg'` support Dask arrays. The `'wilcoxon'` method does not support Dask arrays. + .. note:: + **Wilcoxon ranking precision:** `'wilcoxon'` and `'wilcoxon_binned'` + rank values in float32 on every code path, while means and log fold + changes are computed in float64. This only diverges from Scanpy when the + **preprocessing itself ran in float64** — i.e. normalization/log1p + produced values carrying sub-float32 precision. If preprocessing was + done in float32 (the common case), the values are float32-exact and + ranking is bit-identical to Scanpy (~1e-13), even if they are afterward + stored as float64. For a fully float64 pipeline the rank-derived scores + and p-values still match Scanpy-on-float64 to ~1e-4 on log-normalized + data — below any significance threshold and changing no DE calls — + because the rank-sum normal approximation is insensitive to sub-float32 + tie jitter. If exact float64 ranking matters for your workflow, please + open an issue at https://github.com/scverse/rapids_singlecell/issues. + Parameters ---------- adata @@ -226,7 +241,6 @@ def rank_genes_groups( if key_added is None: key_added = "rank_genes_groups" - # Process mask_var: convert string to boolean array mask_var_array: NDArray[np.bool_] | None = None if mask_var is not None: if isinstance(mask_var, str): @@ -253,7 +267,6 @@ def rank_genes_groups( skip_empty_groups=skip_empty_groups, ) - # Determine n_genes_user n_genes_user = n_genes if n_genes_user is None or n_genes_user > test_obj.X.shape[1]: n_genes_user = test_obj.X.shape[1] diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 72c1a32a..4991b9ce 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -13,7 +13,12 @@ from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _reject_complex, _select_groups, _sparse_has_negative +from ._utils import ( + EPS, + _canonicalize_sparse, + _select_groups, + _sparse_has_negative, +) _RANK_SORT_MIN_ELEMENTS = 1_000_000 _RANK_SORT_MAX_WORKERS = 64 @@ -44,7 +49,6 @@ def __init__( pre_load: bool = False, skip_empty_groups: bool = False, ) -> None: - # Handle groups parameter if groups == "all" or groups is None: selected: list | None = None elif isinstance(groups, str | int): @@ -74,7 +78,6 @@ def __init__( skip_empty_groups=skip_empty_groups, ) - # Get data matrix if layer is not None: if use_raw is True: msg = "Cannot specify `layer` and have `use_raw=True`." @@ -94,7 +97,6 @@ def __init__( self.X = adata.X self.var_names = adata.var_names - # Apply mask_var to select subset of genes if mask_var is not None: self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] @@ -105,7 +107,7 @@ def __init__( if reference != "rest": self.ireference = int(np.where(self.groups_order == str(reference))[0][0]) - # Set up expm1 function based on log base + # expm1 function depends on the log base used by log1p self.is_log1p = "log1p" in adata.uns base = adata.uns.get("log1p", {}).get("base") self._log1p_base = base @@ -114,7 +116,6 @@ def __init__( else: self.expm1_func = np.expm1 - # For basic stats self.comp_pts = comp_pts self.means: np.ndarray | None = None self.vars: np.ndarray | None = None @@ -163,7 +164,6 @@ def _basic_stats(self) -> None: """ n_genes = self.X.shape[1] - # Check if data is already on GPU try: _check_gpu_X(self.X, allow_dask=True) except TypeError: @@ -172,12 +172,11 @@ def _basic_stats(self) -> None: is_on_gpu = True if not is_on_gpu: - # Data not on GPU - defer to chunk-based computation + # Not on GPU: defer to chunk-based computation in the wilcoxon loop self._compute_stats_in_chunks = True self._init_stats_arrays(n_genes) return - # Data is on GPU - use Aggregate for fast computation self._compute_stats_in_chunks = False agg = Aggregate(groupby=self.labels.cat, data=self.X) @@ -204,7 +203,6 @@ def _basic_stats(self) -> None: sums = sums_all[order] sq_sums = sq_sums_all[order] - # Compute means and variances from raw sums (all on GPU) means = sums / n group_ss = sq_sums - n * means**2 vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0) @@ -239,7 +237,6 @@ def _basic_stats(self) -> None: self.vars_rest = None self.pts_rest = None - # Transfer to CPU self.means = cp.asnumpy(means) self.vars = cp.asnumpy(vars_) self.pts = cp.asnumpy(pts) if pts is not None else None @@ -279,20 +276,17 @@ def _accumulate_chunk_stats_vs_rest( stream=cp.cuda.get_current_stream().ptr, ) - # Means chunk_means = group_sums / group_sizes_dev[:, None] self.means[:, start:stop] = cp.asnumpy(chunk_means) - # Variances (with Bessel correction) + # variance with Bessel correction chunk_vars = group_sum_sq / group_sizes_dev[:, None] - chunk_means**2 chunk_vars *= group_sizes_dev[:, None] / (group_sizes_dev[:, None] - 1) self.vars[:, start:stop] = cp.asnumpy(chunk_vars) - # Pts (fraction expressing) if self.comp_pts: self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) - # Rest statistics if self.ireference is None: total_sum = block.sum(axis=0) total_sum_sq = (block**2).sum(axis=0) @@ -386,13 +380,16 @@ def compute_statistics( # The optimized sparse Wilcoxon paths inject implicit zeros analytically # as a tie at the column minimum (valid only for nonnegative data). # t-test/logreg are mean/variance/model-based and sign-agnostic. For the - # Wilcoxon methods we reject complex input and, when sparse data holds + # Wilcoxon methods we canonicalize and, when sparse data holds # negatives, fall back to the dense full-sort ranking (correct for any # sign) rather than erroring -- so e.g. signed sparse data still ranks # correctly, just via the dense path. self._sparse_negative_fallback = False if method in {"wilcoxon", "wilcoxon_binned"}: - _reject_complex(self.X) + # Canonicalize before the negative check: summing duplicates can + # change stored values (e.g. +a and -a -> 0), and the fast paths + # rank each stored nnz once, so they must see scanpy's summed view. + self.X = _canonicalize_sparse(self.X) self._sparse_negative_fallback = _sparse_has_negative(self.X) if self.pre_load or method in { "t-test", diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 2a1c8f4a..3814f662 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -17,18 +17,6 @@ MIN_GROUP_SIZE_WARNING = 25 -def _reject_complex(X) -> None: - """Reject complex expression values (unsupported by every rank method).""" - dtype = None - if sp.issparse(X) or cpsp.issparse(X): - dtype = np.dtype(X.data.dtype) - elif isinstance(X, np.ndarray | cp.ndarray): - dtype = np.dtype(X.dtype) - if dtype is not None and dtype.kind == "c": - msg = "rank_genes_groups does not support complex expression values." - raise TypeError(msg) - - def _sparse_has_negative(X) -> bool: """Whether X is a sparse matrix holding an explicit negative value. @@ -45,6 +33,23 @@ def _sparse_has_negative(X) -> bool: return False +def _canonicalize_sparse(X): + """Sum duplicate entries and sort indices of sparse ``X`` in place. + + The fast Wilcoxon paths rank each stored nonzero once, so non-canonical + input with duplicate ``(row, col)`` entries would diverge from scanpy, + which sums duplicates when it densifies. Canonicalizing keeps them in + agreement. A no-op for already-canonical or dense input. + """ + if ( + (sp.issparse(X) or cpsp.issparse(X)) + and getattr(X, "format", None) in {"csr", "csc"} + and not X.has_canonical_format + ): + X.sum_duplicates() # also sorts indices and sets the canonical flag + return X + + def _select_groups( labels: pd.Series, selected: list | None, @@ -124,7 +129,6 @@ def _select_groups( np.int64 ) - # Validate singlet groups invalid_groups = {str(selected[i]) for i in range(n_groups) if group_sizes[i] < 2} if invalid_groups: msg = ( diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 70316b61..b1fb9168 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -158,9 +158,7 @@ def _fill_basic_stats_from_accumulators( total_sums: cp.ndarray | None = None, total_nnz: cp.ndarray | None = None, ) -> None: - # Wilcoxon does not output per-group variance; vars are left zero (real - # group means/pts come from group_sums/group_nnz, which drive the lfc/pts - # output and the rank test). + # vars left zero: wilcoxon does not output per-group variance. n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] means = group_sums / n rg.means = cp.asnumpy(means) @@ -336,13 +334,11 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X): ) raise TypeError(msg) - # Row/column indices always fit int32 (cells and genes are < 2^31); only the - # indptr (cumulative nnz) can need int64. Mirrors the rest of the sparse code. + # Indices fit int32 (cells/genes < 2^31); only indptr (cumulative nnz) needs int64. is_i64 = X.indptr.dtype == np.int64 - # The *_f64 binding only changes the host pointer dtype so float64 data can - # be passed without a host-side copy; it still ranks in float32 on-device - # (the kernels cast InT -> float before the segmented sort). See - # _device_sparse_arrays_f32 for why float32 ranking is the uniform design. + # The *_f64 binding only changes the host pointer dtype to accept float64 data + # without a host copy; it still ranks in float32 on-device (kernels cast InT -> + # float before the segmented sort). See _device_sparse_arrays_f32. suffix = "" if is_f64: suffix += "_f64" @@ -367,9 +363,15 @@ def _device_sparse_arrays_f32(X): Wilcoxon ranking sorts float32 keys on every path -- the sparse fast paths AND the dense fallback (``_get_dense_column_block_f32``); the CUB segmented sort is float-keyed throughout. Casting ``X.data`` to float32 here therefore - does not diverge from any float64 ranking path, because there is none. For - count data float32 is exact (integer values < 2**24) and scanpy parity holds - at 1e-13. float64 input is accepted only to spare the caller a pre-cast. + does not diverge from any float64 ranking path, because there is none. This + only loses precision when preprocessing ran in float64; float32-preprocessed + values (even if later stored as float64) are float32-exact, so ranking + matches scanpy bit-for-bit (~1e-13). For a fully float64 pipeline the + rank-derived scores/p-values match scanpy-on-float64 to ~1e-4 on + log-normalized data (below any significance threshold, no DE calls change), + while means and log fold changes are still computed in float64. See the + ``rank_genes_groups`` note on ranking precision. float64 input is accepted + to spare the caller a pre-cast. """ data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float32 or data_dtype == np.float64: @@ -475,14 +477,13 @@ def wilcoxon( ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" _maybe_preload_host_dense(rg) - # Compute basic stats - uses Aggregate if on GPU, else defers to chunks + # Aggregate if on GPU, else defer to chunks. rg._basic_stats() X = rg.X n_cells, n_total_genes = rg.X.shape group_sizes = rg.group_sizes if rg.ireference is not None: - # Compare each group against a specific reference group return _wilcoxon_with_reference( rg, X, @@ -493,7 +494,6 @@ def wilcoxon( chunk_size=chunk_size, return_u_values=return_u_values, ) - # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( rg, X, @@ -522,7 +522,6 @@ def _wilcoxon_vs_rest( """Wilcoxon test: each group vs rest of cells.""" n_groups = len(rg.groups_order) - # Warn for small groups for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: @@ -703,7 +702,6 @@ def _wilcoxon_vs_rest( chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) - # Accumulate results per group all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} @@ -760,7 +758,6 @@ def _wilcoxon_vs_rest( all_scores[idx].append(scores_host[idx]) all_pvals[idx].append(p_host[idx]) - # Collect results per group return [ (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) for gi in range(n_groups) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index fb18ae36..e65a89db 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -78,29 +78,6 @@ def test_rank_genes_groups_sparse_negative_values_fallback(method, fmt): ) -@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) -def test_rank_genes_groups_complex_values_raise(fmt): - X = np.array( - [ - [1.0 + 0.0j, 0.0, 2.0], - [0.0, 1.0, 0.0], - [2.0, 0.0, 1.0], - [0.0, 3.0, 0.0], - ], - dtype=np.complex64, - ) - adata = sc.AnnData( - X=_to_format(X, fmt), - obs=pd.DataFrame( - {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} - ), - var=pd.DataFrame(index=["g0", "g1", "g2"]), - ) - - with pytest.raises(TypeError, match="complex expression values"): - rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) - - @pytest.mark.parametrize("layout", ["csr", "csc"]) @pytest.mark.parametrize("reference", ["rest", "1"]) def test_device_sparse_int64_indptr_matches_scanpy(layout, reference): From 8388d86c0f4331cd827c602d29246975f489bfb6 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 22 Jun 2026 15:18:33 +0200 Subject: [PATCH 22/36] start dedup --- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 83 +--- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 468 +++++------------- .../_cuda/wilcoxon/wilcoxon.cu | 20 +- .../_cuda/wilcoxon/wilcoxon_block_reduce.cuh | 27 + .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 26 + .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 24 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 24 +- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 10 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 58 +-- .../_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh | 72 +++ .../wilcoxon/wilcoxon_sparse_kernels.cuh | 142 ++---- 11 files changed, 355 insertions(+), 599 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 5e718d24..e7064ce8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,26 +2,8 @@ #include -__device__ __forceinline__ double wilcoxon_block_sum(double val, - double* warp_buf) { -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = val; - __syncthreads(); - if (threadIdx.x < 32) { - double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - v += __shfl_down_sync(0xffffffff, v, off); - return v; - } - return 0.0; -} +#include "wilcoxon_block_reduce.cuh" +#include "wilcoxon_ovr_tie_walk.cuh" // Dense OVR rank kernel. sorted_vals/sorted_row_idx are F-order arrays from a // segmented SortPairs. One block per column; walks sorted tie runs and @@ -58,64 +40,11 @@ __global__ void rank_sums_from_sorted_kernel( int my_end = my_start + chunk; if (my_end > n_rows) my_end = n_rows; - double local_tie_sum = 0.0; int acc_stride = use_gmem ? n_cols : 1; - - int i = my_start; - while (i < my_end) { - double val = sv[i]; - - int tie_local_end = i + 1; - while (tie_local_end < my_end && sv[tie_local_end] == val) { - ++tie_local_end; - } - - int tie_global_start = i; - if (i == my_start && i > 0 && sv[i - 1] == val) { - int lo = 0; - int hi = i; - while (lo < hi) { - int mid = lo + ((hi - lo) >> 1); - if (sv[mid] < val) - lo = mid + 1; - else - hi = mid; - } - tie_global_start = lo; - } - - int tie_global_end = tie_local_end; - if (tie_local_end == my_end && tie_local_end < n_rows && - sv[tie_local_end] == val) { - int lo = tie_local_end; - int hi = n_rows - 1; - while (lo < hi) { - int mid = hi - ((hi - lo) >> 1); - if (sv[mid] > val) - hi = mid - 1; - else - lo = mid; - } - tie_global_end = lo + 1; - } - - int total_tie = tie_global_end - tie_global_start; - double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - - for (int j = i; j < tie_local_end; ++j) { - int grp = group_codes[si[j]]; - if (grp >= 0 && grp < n_groups) { - atomicAdd(&grp_sums[grp * acc_stride], avg_rank); - } - } - - if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { - double t = (double)total_tie; - local_tie_sum += t * t * t - t; - } - - i = tie_local_end; - } + double local_tie_sum = ovr_walk_tie_runs( + sv, si, group_codes, grp_sums, acc_stride, n_groups, my_start, my_end, + /*seg_floor=*/0, /*seg_ceil=*/n_rows, /*rank_offset=*/0.0, + compute_tie_corr); __syncthreads(); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 6383bf5d..b2c88820 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -2,31 +2,97 @@ #include +#include "wilcoxon_block_reduce.cuh" #include "wilcoxon_fast_common.cuh" // ============================================================================ -// Warp reduction helper (sum doubles across block via warp_buf) +// Bitonic sort of `n` floats in shared memory, ascending. `n` must be a power +// of two; pad the tail with +INF before calling. Grid-stride, so any blockDim +// works (covers both the LARGE runtime-sized and SMALL fixed-size paths). // ============================================================================ -__device__ __forceinline__ double block_reduce_sum(double val, - double* warp_buf) { -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = val; - __syncthreads(); - if (threadIdx.x < 32) { - double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - v2 += __shfl_down_sync(0xffffffff, v2, off); - return v2; // only lane 0 of warp 0 has the final result +__device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { + for (int k = 2; k <= n; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < n; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = s[i], b = s[ixj]; + if (asc ? (a > b) : (a < b)) { + s[i] = b; + s[ixj] = a; + } + } + } + __syncthreads(); + } } - return 0.0; +} + +// ============================================================================ +// Sorted-array bounds over [lo, hi). lower: first index with arr[idx] >= v +// (count of elements < v). upper: first index with arr[idx] > v (count <= v). +// Pass an advanced `lo` to exploit per-thread-stride monotonicity. Work for +// both global and shared `arr`. +// ============================================================================ + +__device__ __forceinline__ int sorted_lower_bound(const float* arr, int lo, + int hi, float v) { + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (arr[m] < v) + lo = m + 1; + else + hi = m; + } + return lo; +} + +__device__ __forceinline__ int sorted_upper_bound(const float* arr, int lo, + int hi, float v) { + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (arr[m] <= v) + lo = m + 1; + else + hi = m; + } + return lo; +} + +// Mid-rank of `v` in the merged (ref, grp) arrays. Advances the four +// incremental bounds (pass 0,0,0,0 for a fresh per-element search) and reports +// the per-array equal counts for tie correction. +struct OvoRank { + double mid_rank; + int n_eq_ref; + int n_eq_grp; +}; + +__device__ __forceinline__ OvoRank ovo_mid_rank(const float* ref, int n_ref, + const float* grp, int n_grp, + float v, int& ref_lb, + int& ref_ub, int& grp_lb, + int& grp_ub) { + int n_lt_ref = sorted_lower_bound(ref, ref_lb, n_ref, v); + ref_lb = n_lt_ref; + ref_ub = sorted_upper_bound(ref, ref_ub > n_lt_ref ? ref_ub : n_lt_ref, + n_ref, v); + int n_eq_ref = ref_ub - n_lt_ref; + + int n_lt_grp = sorted_lower_bound(grp, grp_lb, n_grp, v); + grp_lb = n_lt_grp; + grp_ub = sorted_upper_bound(grp, grp_ub > n_lt_grp ? grp_ub : n_lt_grp, + n_grp, v); + int n_eq_grp = grp_ub - n_lt_grp; + + OvoRank r; + r.mid_rank = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + r.n_eq_ref = n_eq_ref; + r.n_eq_grp = n_eq_grp; + return r; } // ============================================================================ @@ -50,39 +116,13 @@ __device__ __forceinline__ void compute_tie_correction_parallel( if (i == 0 || ref_col[i] != ref_col[i - 1]) { float v = ref_col[i]; - int lo = i + 1, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_ref = lo - i; - - lo = grp_lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; - } - int lb = lo; - grp_lb = lb; + int cnt_ref = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; - lo = (grp_ub > lb) ? grp_ub : lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_grp = lo - lb; - grp_ub = lo; + int lb = sorted_lower_bound(grp_col, grp_lb, n_grp, v); + grp_lb = lb; + grp_ub = sorted_upper_bound(grp_col, grp_ub > lb ? grp_ub : lb, + n_grp, v); + int cnt_grp = grp_ub - lb; int cnt = cnt_ref + cnt_grp; if (cnt > 1) { @@ -98,28 +138,12 @@ __device__ __forceinline__ void compute_tie_correction_parallel( if (i == 0 || grp_col[i] != grp_col[i - 1]) { float v = grp_col[i]; - int lo = ref_lb, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } + int lo = sorted_lower_bound(ref_col, ref_lb, n_ref, v); ref_lb = lo; if (lo >= n_ref || ref_col[lo] != v) { // Value absent from ref — count in grp only - lo = i + 1; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt = lo - i; + int cnt = sorted_upper_bound(grp_col, i + 1, n_grp, v) - i; if (cnt > 1) { double t = (double)cnt; local_tie += t * t * t - t; @@ -128,7 +152,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( } } - double tie_sum = block_reduce_sum(local_tie, warp_buf); + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); if (threadIdx.x == 0) { int n = n_ref + n_grp; double dn = (double)n; @@ -177,63 +201,13 @@ __global__ void ovo_rank_huge_kernel( double local_sum = 0.0; for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; - int lo, hi; - - lo = ref_lb; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - ref_lb = n_lt_ref; - - lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; - ref_ub = lo; - - lo = grp_lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_grp = lo; - grp_lb = n_lt_grp; - - lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_grp = lo - n_lt_grp; - grp_ub = lo; - - local_sum += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_col, n_grp, grp_col[i], + ref_lb, ref_ub, grp_lb, grp_ub); + local_sum += r.mid_rank; } __shared__ double warp_buf[32]; - double total = block_reduce_sum(local_sum, warp_buf); + double total = wilcoxon_block_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; @@ -294,22 +268,7 @@ __global__ void ovo_rank_large_kernel( __syncthreads(); // Bitonic sort in shared memory - for (int k = 2; k <= large_padded; k <<= 1) { - for (int j = k >> 1; j > 0; j >>= 1) { - for (int i = threadIdx.x; i < large_padded; i += blockDim.x) { - int ixj = i ^ j; - if (ixj > i) { - bool asc = ((i & k) == 0); - float a = grp_smem[i], b = grp_smem[ixj]; - if (asc ? (a > b) : (a < b)) { - grp_smem[i] = b; - grp_smem[ixj] = a; - } - } - } - __syncthreads(); - } - } + bitonic_sort_smem(grp_smem, large_padded); // Binary search each sorted grp element against sorted ref; // incremental bounds (monotonic within each thread's stride) @@ -319,62 +278,12 @@ __global__ void ovo_rank_large_kernel( double local_sum = 0.0; for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_smem[i]; - int lo, hi; - - lo = ref_lb; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - ref_lb = n_lt_ref; - - lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; - ref_ub = lo; - - lo = grp_lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_smem[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_grp = lo; - grp_lb = n_lt_grp; - - lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_smem[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_grp = lo - n_lt_grp; - grp_ub = lo; - - local_sum += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_smem, n_grp, grp_smem[i], + ref_lb, ref_ub, grp_lb, grp_ub); + local_sum += r.mid_rank; } - double total = block_reduce_sum(local_sum, warp_buf); + double total = wilcoxon_block_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; @@ -402,15 +311,7 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { if (i == 0 || ref_col[i] != ref_col[i - 1]) { float v = ref_col[i]; - int lo = i + 1, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt = lo - i; + int cnt = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; if (cnt > 1) { double t = (double)cnt; local_tie += t * t * t - t; @@ -419,7 +320,7 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, } __shared__ double warp_buf[32]; - double total = block_reduce_sum(local_tie, warp_buf); + double total = wilcoxon_block_sum(local_tie, warp_buf); if (threadIdx.x == 0) ref_tie_sums[col] = total; } @@ -449,21 +350,7 @@ __global__ void ovo_rank_small_kernel( } __syncthreads(); - for (int k = 2; k <= OVO_SMALL_MAX; k <<= 1) { - for (int j = k >> 1; j > 0; j >>= 1) { - int i = threadIdx.x; - int ixj = i ^ j; - if (i < OVO_SMALL_MAX && ixj > i) { - bool asc = ((i & k) == 0); - float a = grp_smem[i], b = grp_smem[ixj]; - if (asc ? (a > b) : (a < b)) { - grp_smem[i] = b; - grp_smem[ixj] = a; - } - } - __syncthreads(); - } - } + bitonic_sort_smem(grp_smem, OVO_SMALL_MAX); const float* ref_col = ref_sorted + (long long)col * n_ref; double local_sum = 0.0; @@ -471,69 +358,31 @@ __global__ void ovo_rank_small_kernel( if (threadIdx.x < n_grp) { float v = grp_smem[threadIdx.x]; - - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; - - lo = 0; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_smem[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_grp = lo; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_smem[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_grp = lo - n_lt_grp; - - local_sum += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0; + OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_smem, n_grp, v, ref_lb, + ref_ub, grp_lb, grp_ub); + local_sum += r.mid_rank; if (compute_tie_corr && (threadIdx.x == 0 || v != grp_smem[threadIdx.x - 1])) { - double combined = (double)(n_eq_ref + n_eq_grp); + double combined = (double)(r.n_eq_ref + r.n_eq_grp); if (combined > 1.0) { local_tie_delta += combined * combined * combined - combined; } - if (n_eq_ref > 1) { - double cr = (double)n_eq_ref; + if (r.n_eq_ref > 1) { + double cr = (double)r.n_eq_ref; local_tie_delta -= cr * cr * cr - cr; } } } - double total = block_reduce_sum(local_sum, warp_buf); + double total = wilcoxon_block_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; __syncthreads(); - double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); if (threadIdx.x == 0) { int n = n_ref + n_grp; double dn = (double)n; @@ -584,25 +433,9 @@ __global__ void ovo_rank_medium_kernel( for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { float v = grp_smem[i]; - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; + int n_lt_ref = sorted_lower_bound(ref_col, 0, n_ref, v); + int n_eq_ref = + sorted_upper_bound(ref_col, n_lt_ref, n_ref, v) - n_lt_ref; int n_lt_grp = 0; int n_eq_grp = 0; @@ -633,13 +466,13 @@ __global__ void ovo_rank_medium_kernel( } } - double total = block_reduce_sum(local_sum, warp_buf); + double total = wilcoxon_block_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; __syncthreads(); - double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); if (threadIdx.x == 0) { int n = n_ref + n_grp; double dn = (double)n; @@ -674,15 +507,7 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); int cnt_ref = 0; if (is_first) { - int lo = i + 1, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - cnt_ref = lo - i; + cnt_ref = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; } // Count in grp: look up how many lanes hold v_lane == v. All lanes @@ -717,14 +542,7 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, bool first_in_grp = (lane == 0) || (v != v_prev); bool in_ref = false; if (first_in_grp) { - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } + int lo = sorted_lower_bound(ref_col, 0, n_ref, v); in_ref = (lo < n_ref) && (ref_col[lo] == v); } @@ -777,24 +595,9 @@ __device__ __forceinline__ double warp_tie_delta(const float* ref_col, } if (first_in_grp) { - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int ref_lb = lo; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_ref = lo - ref_lb; + int ref_lb = sorted_lower_bound(ref_col, 0, n_ref, v); + int cnt_ref = + sorted_upper_bound(ref_col, ref_lb, n_ref, v) - ref_lb; double combined = (double)(cnt_ref + cnt_grp); if (combined > 1.0) { @@ -882,24 +685,9 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, if (lane < n_grp) { float v = x; - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; + int n_lt_ref = sorted_lower_bound(ref_col, 0, n_ref, v); + int n_eq_ref = + sorted_upper_bound(ref_col, n_lt_ref, n_ref, v) - n_lt_ref; // In-group counts: in the sorted warp-register x, count lanes < this // one that hold strictly less, and lanes with equal value. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index c1c30268..8dca8b00 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -85,12 +85,10 @@ static void launch_ovr_rank_dense_streaming( CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); const float* keys_in = block + (size_t)col * n_rows; - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "dense OVR segmented sort"); + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "dense OVR segmented sort"); if (use_gmem) { cuda_check(cudaMemsetAsync( @@ -260,12 +258,10 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( const float* ref_sub = ref_data + (size_t)col * n_ref; const float* grp_sub = grp_data + (size_t)col * n_all_grp; upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - size_t ref_temp = ref_cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - buf.ref_cub_temp, ref_temp, ref_sub, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "dense OVO ref segmented sort"); + cub_segmented_sortkeys(buf.ref_cub_temp, ref_cub_temp_bytes, ref_sub, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "dense OVO ref segmented sort"); ref_sub = buf.ref_sorted; OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh new file mode 100644 index 00000000..a11c20e8 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include + +// Block-wide sum of `val` across all threads. `warp_buf` is shared scratch +// holding one double per warp (>= ceil(blockDim.x / 32) <= 32). Result is +// returned on thread 0 (lane 0 of warp 0); other threads get 0.0. +__device__ __forceinline__ double wilcoxon_block_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; + } + return 0.0; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index da1aa18c..f73fc918 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -112,6 +112,32 @@ static inline size_t cub_segmented_sortpairs_temp_bytes(int num_items, return bytes; } +// Launch wrappers for the queries above. begin/end offset arrays may be +// contiguous (off, off + 1) or distinct (starts, ends). +static inline void cub_segmented_sortkeys( + void* d_temp, size_t temp_bytes, const float* keys_in, float* keys_out, + int num_items, int num_segments, const int* begin_offsets, + const int* end_offsets, cudaStream_t stream, const char* what) { + cuda_check( + cub::DeviceSegmentedRadixSort::SortKeys( + d_temp, temp_bytes, keys_in, keys_out, num_items, num_segments, + begin_offsets, end_offsets, BEGIN_BIT, END_BIT, stream), + what); +} + +template +static inline void cub_segmented_sortpairs( + void* d_temp, size_t temp_bytes, const float* keys_in, float* keys_out, + const ValT* vals_in, ValT* vals_out, int num_items, int num_segments, + const int* begin_offsets, const int* end_offsets, cudaStream_t stream, + const char* what) { + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out, + num_items, num_segments, begin_offsets, end_offsets, + BEGIN_BIT, END_BIT, stream), + what); +} + // Universal CUDA static per-block shared-memory floor; safe fallback if the // device query fails. constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 71c4b64e..cc3e1b55 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -164,13 +164,11 @@ static void ovo_streaming_csr_impl( size_t ref_cub_bytes = cub_segmented_sortkeys_temp_bytes(cache_ref_items_i32, cache_cols); ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); - size_t ref_temp = ref_cub_bytes; - cuda_check( - cub::DeviceSegmentedRadixSort::SortKeys( - ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, - cache_ref_items_i32, cache_cols, d_ref_seg_offsets, - d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream), - "device CSR OVO ref segmented sort"); + cub_segmented_sortkeys(ref_cub_temp_buf.data(), ref_cub_bytes, + d_ref_dense, d_ref_sorted, cache_ref_items_i32, + cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, ref_stream, + "device CSR OVO ref segmented sort"); cuda_check(cudaStreamSynchronize(ref_stream), "device CSR OVO ref sort sync"); @@ -375,14 +373,10 @@ static void ovo_streaming_csc_impl( n_ref, col); CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "device CSC OVO ref segmented sort"); - } + cub_segmented_sortkeys(buf.cub_temp, cub_temp_bytes, buf.ref_dense, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "device CSC OVO ref segmented sort"); cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), stream); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 218d075c..3b03a2e6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -246,14 +246,10 @@ static void ovo_streaming_csc_host_impl( d_ref_row_map, buf.ref_dense, n_ref, 0); CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "host CSC OVO ref segmented sort"); - } + cub_segmented_sortkeys(buf.cub_temp, cub_temp_bytes, buf.ref_dense, + buf.ref_sorted, sb_ref_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "host CSC OVO ref segmented sort"); // Extract grp from CSC via row_map cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), @@ -578,13 +574,11 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR( csr_extract_dense_identity_rows_unsorted_kernel); upload_linear_offsets(d_ref_seg, cc, n_ref, ref_stream); - size_t temp = ref_cub_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp_buf.data(), temp, d_ref_dense, - d_ref_sorted + (size_t)cs * (size_t)n_ref, - (int)chunk_items, cc, d_ref_seg, d_ref_seg + 1, - BEGIN_BIT, END_BIT, ref_stream), - "host CSR OVO ref segmented sort"); + cub_segmented_sortkeys( + cub_temp_buf.data(), ref_cub_bytes, d_ref_dense, + d_ref_sorted + (size_t)cs * (size_t)n_ref, (int)chunk_items, cc, + d_ref_seg, d_ref_seg + 1, ref_stream, + "host CSR OVO ref segmented sort"); } cuda_check(cudaStreamSynchronize(ref_stream), "host CSR OVO ref sort sync"); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index 84ed7fab..a9042fed 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -255,12 +255,10 @@ static inline void ovo_dispatch_tiers( n_all_grp, n_sort_groups, sb_cols); CUDA_CHECK_LAST_ERROR(build_huge_seg_offsets_kernel); - size_t temp = grp_cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( - sc.grp_cub_temp, temp, grp_dense, sc.grp_sorted, - sb_grp_items_actual, sb_grp_seg, sc.grp_seg_offsets, - sc.grp_seg_ends, BEGIN_BIT, END_BIT, stream), - "OVO huge-tier group segmented sort"); + cub_segmented_sortkeys(sc.grp_cub_temp, grp_cub_temp_bytes, grp_dense, + sc.grp_sorted, sb_grp_items_actual, sb_grp_seg, + sc.grp_seg_offsets, sc.grp_seg_ends, stream, + "OVO huge-tier group segmented sort"); dim3 grid(sb_cols, n_groups); ovo_rank_huge_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 6ee17a0a..866bf03c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -176,13 +176,11 @@ static void ovr_sparse_csc_host_streaming_impl( // Sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_sparse_data_f32, - buf.keys_out, buf.d_sparse_indices, buf.vals_out, - batch_nnz, sb_cols, buf.d_seg_offsets, - buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "host CSC OVR segmented sort"); + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, buf.d_sparse_data_f32, + buf.keys_out, buf.d_sparse_indices, buf.vals_out, batch_nnz, + sb_cols, buf.d_seg_offsets, buf.d_seg_offsets + 1, stream, + "host CSC OVR segmented sort"); } launch_ovr_sparse_rank( @@ -430,12 +428,10 @@ static void ovr_sparse_csr_host_rowstream_impl( d_group_codes, sub_group_sums, sub_group_nnz, sb_cols, n_groups, compute_nnz, tpb, smem_cast, cast_use_gmem, 0); if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - cub_temp, temp, csc_vals_f32, keys_out, csc_row_idx, - vals_out, batch_nnz, sb_cols, col_offsets, - col_offsets + 1, BEGIN_BIT, END_BIT), - "rowstream segmented sort"); + cub_segmented_sortpairs(cub_temp, cub_temp_bytes, csc_vals_f32, + keys_out, csc_row_idx, vals_out, batch_nnz, + sb_cols, col_offsets, col_offsets + 1, 0, + "rowstream segmented sort"); } launch_ovr_sparse_rank( keys_out, vals_out, col_offsets, d_group_codes, d_group_sizes, @@ -727,13 +723,11 @@ static void ovr_sparse_csr_host_streaming_impl( cast_use_gmem, stream); if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, - buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, - buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, - END_BIT, stream), - "host CSR OVR segmented sort"); + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, stream, + "host CSR OVR segmented sort"); } launch_ovr_sparse_rank( @@ -883,13 +877,11 @@ static void ovr_sparse_csc_streaming_impl( // Sort only stored values (keys=data, vals=row_indices) if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, csc_data + ptr_start, - buf.keys_out, csc_indices + ptr_start, buf.vals_out, - batch_nnz, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream), - "device CSC OVR segmented sort"); + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, csc_data + ptr_start, + buf.keys_out, csc_indices + ptr_start, buf.vals_out, batch_nnz, + sb_cols, buf.seg_offsets, buf.seg_offsets + 1, stream, + "device CSC OVR segmented sort"); } // Sparse rank kernel (handles implicit zeros analytically) @@ -1094,13 +1086,11 @@ static void ovr_sparse_csr_streaming_impl( CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); // Sort only the nonzeros - size_t temp = cub_temp_bytes; - cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.csc_vals, buf.keys_out, - buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, - buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, - END_BIT, stream), - "device CSR OVR segmented sort"); + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, buf.csc_vals, + buf.keys_out, buf.csc_row_idx, buf.vals_out, + batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, stream, + "device CSR OVR segmented sort"); } // Sparse rank kernel (handles implicit zeros analytically) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh new file mode 100644 index 00000000..09398f30 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh @@ -0,0 +1,72 @@ +#pragma once + +#include + +// Walk this thread's chunk [my_start, my_end) of a sorted column, accumulating +// tie-averaged ranks into grp_sums (atomic, strided by acc_stride). Ties that +// straddle a chunk boundary are expanded to their global extent within +// [seg_floor, seg_ceil) by binary search. `rank_offset` shifts every rank (the +// sparse path uses it to account for implicit leading zeros). Returns this +// thread's tie-correction sum (sum of t^3 - t over tie blocks it owns). +template +__device__ __forceinline__ double ovr_walk_tie_runs( + const float* sv, const IndexT* si, const int* group_codes, double* grp_sums, + int acc_stride, int n_groups, int my_start, int my_end, int seg_floor, + int seg_ceil, double rank_offset, bool compute_tie_corr) { + double local_tie_sum = 0.0; + int i = my_start; + while (i < my_end) { + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > seg_floor && sv[i - 1] == val) { + // tie spans into a prior chunk: find global tie start. + int lo = seg_floor, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < seg_ceil && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = seg_ceil - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + double avg_rank = + rank_offset + (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp >= 0 && grp < n_groups) { + atomicAdd(&grp_sums[(size_t)grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + return local_tie_sum; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 90e66793..14d3ef0c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -2,6 +2,9 @@ #include +#include "wilcoxon_block_reduce.cuh" +#include "wilcoxon_ovr_tie_walk.cuh" + /** * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. * @@ -133,63 +136,10 @@ __global__ void rank_sums_sparse_ovr_kernel( int my_end = my_start + chunk; if (my_end > nnz_stored) my_end = nnz_stored; - double local_tie_sum = 0.0; - - int i = my_start; - while (i < my_end) { - float val = sv[i]; - - int tie_local_end = i + 1; - while (tie_local_end < my_end && sv[tie_local_end] == val) - ++tie_local_end; - - int tie_global_start = i; - if (i == my_start && i > 0 && sv[i - 1] == val) { - // tie spans into a prior chunk: find global tie start. - int lo = pos_start, hi = i; - while (lo < hi) { - int mid = lo + ((hi - lo) >> 1); - if (sv[mid] < val) - lo = mid + 1; - else - hi = mid; - } - tie_global_start = lo; - } - - int tie_global_end = tie_local_end; - if (tie_local_end == my_end && tie_local_end < nnz_stored && - sv[tie_local_end] == val) { - int lo = tie_local_end, hi = nnz_stored - 1; - while (lo < hi) { - int mid = hi - ((hi - lo) >> 1); - if (sv[mid] > val) - hi = mid - 1; - else - lo = mid; - } - tie_global_end = lo + 1; - } - - int total_tie = tie_global_end - tie_global_start; - - double avg_rank = (double)offset_pos + - (double)(tie_global_start + tie_global_end + 1) / 2.0; - - for (int j = i; j < tie_local_end; ++j) { - int grp = group_codes[si[j]]; - if (grp >= 0 && grp < n_groups) { - atomicAdd(&grp_sums[(size_t)grp * acc_stride], avg_rank); - } - } - - if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { - double t = (double)total_tie; - local_tie_sum += t * t * t - t; - } - - i = tie_local_end; - } + double local_tie_sum = ovr_walk_tie_runs( + sv, si, group_codes, grp_sums, acc_stride, n_groups, my_start, my_end, + /*seg_floor=*/pos_start, /*seg_ceil=*/nnz_stored, + /*rank_offset=*/(double)offset_pos, compute_tie_corr); __syncthreads(); @@ -213,25 +163,11 @@ __global__ void rank_sums_sparse_ovr_kernel( int warp_buf_off = use_gmem ? 0 : 2 * n_groups; double* warp_buf = smem + warp_buf_off; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = local_tie_sum; - __syncthreads(); - if (threadIdx.x < 32) { - double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - v += __shfl_down_sync(0xffffffff, v, off); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; - } + double v = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; } } } @@ -284,6 +220,29 @@ static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, return 0; } +// Shared cast+accumulate loop for the two sparse-OVR stats kernels. Casts each +// stored value to f32 (data_f32_out) and atomically accumulates per-group sums +// (and nonzero counts) into sums/nnz, strided by acc_stride (1 for a per-block +// smem buffer, sb_cols for the global row-major layout). +template +__device__ __forceinline__ void accumulate_group_stats( + const InT* data_in, float* data_f32_out, const IndexT* indices, + int seg_start, int seg_end, const int* group_codes, double* sums, + double* nnz, int acc_stride, int n_groups, bool compute_nnz) { + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g >= 0 && g < n_groups) { + atomicAdd(&sums[(size_t)g * acc_stride], v); + if (compute_nnz && v != 0.0) + atomicAdd(&nnz[(size_t)g * acc_stride], 1.0); + } + } +} + /** * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. * @@ -321,17 +280,9 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } __syncthreads(); - for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { - InT v_in = data_in[i]; - double v = (double)v_in; - data_f32_out[i] = (float)v_in; - int row = (int)indices[i]; - int g = group_codes[row]; - if (g >= 0 && g < n_groups) { - atomicAdd(&s_sum[g], v); - if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); - } - } + accumulate_group_stats( + data_in, data_f32_out, indices, seg_start, seg_end, group_codes, s_sum, + s_nnz, /*acc_stride=*/1, n_groups, compute_nnz); __syncthreads(); for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { @@ -360,19 +311,10 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( int seg_start = col_seg_offsets[col]; int seg_end = col_seg_offsets[col + 1]; - for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { - InT v_in = data_in[i]; - double v = (double)v_in; - data_f32_out[i] = (float)v_in; - int row = (int)indices[i]; - int g = group_codes[row]; - if (g >= 0 && g < n_groups) { - atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); - if (compute_nnz && v != 0.0) { - atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); - } - } - } + accumulate_group_stats( + data_in, data_f32_out, indices, seg_start, seg_end, group_codes, + group_sums + col, group_nnz + col, + /*acc_stride=*/sb_cols, n_groups, compute_nnz); } template From 9470896bf8e66a39386ca5863e95232ab68c284a Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 22 Jun 2026 15:33:21 +0200 Subject: [PATCH 23/36] more dedup --- .../_cuda/wilcoxon/wilcoxon.cu | 18 ++--------- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 11 +++++++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 16 ++-------- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 16 ++-------- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 32 +++---------------- 5 files changed, 21 insertions(+), 72 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 8dca8b00..5b183b1f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -119,14 +119,7 @@ static void launch_ovr_rank_dense_streaming( ++batch_idx; } - for (int s = 0; s < n_streams; ++s) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("CUDA error in dense OVR streaming rank: ") + - cudaGetErrorString(err)); - } - } + sync_streams(streams, "dense OVR streaming rank"); } static void launch_ovo_rank_dense_tiered_unsorted_ref( @@ -292,14 +285,7 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( ++batch_idx; } - for (int s = 0; s < n_streams; ++s) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("CUDA error in dense OVO tiered rank: ") + - cudaGetErrorString(err)); - } - } + sync_streams(streams, "dense OVO tiered rank"); } template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index f73fc918..44726c13 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -337,6 +337,17 @@ struct ScopedCudaStreams { ScopedCudaStreams& operator=(const ScopedCudaStreams&) = delete; }; +// Drain every stream, surfacing the first async error with a context label. +static inline void sync_streams(const ScopedCudaStreams& streams, + const char* what) { + for (int i = 0; i < streams.size(); ++i) { + cudaError_t err = cudaStreamSynchronize(streams[i]); + if (err != cudaSuccess) + throw std::runtime_error(std::string("CUDA error in ") + what + + ": " + cudaGetErrorString(err)); + } +} + struct ScopedCudaEvent { cudaEvent_t event = nullptr; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index cc3e1b55..419860d6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -220,13 +220,7 @@ static void ovo_streaming_csr_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in OVO device CSR streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "OVO device CSR streaming"); } } @@ -409,11 +403,5 @@ static void ovo_streaming_csc_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in OVO device CSC streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "OVO device CSC streaming"); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 3b03a2e6..1703c036 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -295,13 +295,7 @@ static void ovo_streaming_csc_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "wilcoxon streaming"); } /** @@ -774,11 +768,5 @@ static void ovo_streaming_csr_host_impl( } } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in ovo csr host streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "ovo csr host streaming"); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 866bf03c..65bbd69a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -213,13 +213,7 @@ static void ovr_sparse_csc_host_streaming_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse host CSC streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "sparse host CSC streaming"); } // ============================================================================ @@ -759,13 +753,7 @@ static void ovr_sparse_csr_host_streaming_impl( col += sb_cols; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse host CSR streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "sparse host CSR streaming"); } // ============================================================================ @@ -905,13 +893,7 @@ static void ovr_sparse_csc_streaming_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse ovr streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "sparse ovr streaming"); } // ============================================================================ @@ -1113,11 +1095,5 @@ static void ovr_sparse_csr_streaming_impl( col += sb_cols; } - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse CSR ovr streaming: ") + - cudaGetErrorString(err)); - } + sync_streams(streams, "sparse CSR ovr streaming"); } From 36a5de4bc23393ebd5cc851b0fc2fd876d36eee6 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 22 Jun 2026 15:55:56 +0200 Subject: [PATCH 24/36] make even smaller --- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 7 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 45 +++------- .../_cuda/wilcoxon/wilcoxon_block_reduce.cuh | 26 ++++-- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 57 +++++++++++++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 8 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 64 ++++---------- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 83 ++++++------------- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 6 +- 8 files changed, 132 insertions(+), 164 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index e7064ce8..db5f492c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -58,10 +58,7 @@ __global__ void rank_sums_from_sorted_kernel( int warp_buf_off = use_gmem ? 0 : n_groups; double* warp_buf = smem + warp_buf_off; double tie_sum = wilcoxon_block_sum(local_tie_sum, warp_buf); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } + if (threadIdx.x == 0) + tie_corr[col] = finalize_tie_corr(n_rows, tie_sum); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index b2c88820..65916525 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -153,12 +153,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( } double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } + if (threadIdx.x == 0) *out = finalize_tie_corr(n_ref + n_grp, tie_sum); } // ============================================================================ @@ -383,14 +378,9 @@ __global__ void ovo_rank_small_kernel( __syncthreads(); double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - double tie_sum = ref_tie_sums[col] + tie_delta; + if (threadIdx.x == 0) tie_corr[grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } + finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); } // ============================================================================ @@ -473,14 +463,9 @@ __global__ void ovo_rank_medium_kernel( __syncthreads(); double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - double tie_sum = ref_tie_sums[col] + tie_delta; + if (threadIdx.x == 0) tie_corr[grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } + finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); } // ============================================================================ @@ -565,9 +550,7 @@ __device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, } } -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + local_tie = warp_reduce_sum(local_tie); return local_tie; // meaningful on lane 0. } @@ -610,9 +593,7 @@ __device__ __forceinline__ double warp_tie_delta(const float* ref_col, } } -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_delta += __shfl_down_sync(0xffffffff, local_delta, off); + local_delta = warp_reduce_sum(local_delta); return local_delta; // meaningful on lane 0. } @@ -715,9 +696,7 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, } // Warp reduce. -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + local_sum = warp_reduce_sum(local_sum); if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; if (!compute_tie_corr) return; @@ -729,11 +708,7 @@ __global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, } else { tie_sum = warp_tie_sum(ref_col, n_ref, x, n_grp, active_mask); } - if (lane == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; + if (lane == 0) tie_corr[grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } + finalize_tie_corr(n_ref + n_grp, tie_sum); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh index a11c20e8..d92238f2 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh @@ -2,14 +2,20 @@ #include +// Sum `v` across the 32 lanes of a warp via shuffle-down; result on lane 0. +__device__ __forceinline__ double warp_reduce_sum(double v) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; +} + // Block-wide sum of `val` across all threads. `warp_buf` is shared scratch // holding one double per warp (>= ceil(blockDim.x / 32) <= 32). Result is // returned on thread 0 (lane 0 of warp 0); other threads get 0.0. __device__ __forceinline__ double wilcoxon_block_sum(double val, double* warp_buf) { -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); + val = warp_reduce_sum(val); int lane = threadIdx.x & 31; int wid = threadIdx.x >> 5; if (lane == 0) warp_buf[wid] = val; @@ -18,10 +24,16 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? warp_buf[threadIdx.x] : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - v += __shfl_down_sync(0xffffffff, v, off); - return v; + return warp_reduce_sum(v); } return 0.0; } + +// Final tie-correction factor: 1 - sum(t^3 - t) / (n^3 - n), or 1.0 when the +// ranking population n_total is too small for a correction. +__device__ __forceinline__ double finalize_tie_corr(int n_total, + double tie_sum) { + double dn = (double)n_total; + double denom = dn * dn * dn - dn; + return (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 44726c13..4a659184 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -65,6 +65,35 @@ constexpr int END_BIT = 32; constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; + +// Stream-count clamps shared by the streaming impls: never use more streams +// than there are column batches, nor more than the per-stream memory budget +// allows. +static inline int clamp_streams_by_cols(int n_cols, int sub_batch_cols) { + int n = N_STREAMS; + if (n_cols < n * sub_batch_cols) + n = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + return n; +} + +static inline int clamp_streams_by_budget(int n_streams, + size_t per_stream_bytes, + size_t budget) { + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + return n_streams; +} + +// Scatter a [rows, sb_cols] device sub-batch block (row-major doubles, source +// row stride sb_cols) into `dst` whose row stride is n_cols. `dst` must already +// point at the destination column offset (e.g. out + col). +static inline void scatter_cols_2d(double* dst, const double* src, int rows, + int n_cols, int sb_cols, + cudaStream_t stream) { + cudaMemcpy2DAsync(dst, n_cols * sizeof(double), src, + sb_cols * sizeof(double), sb_cols * sizeof(double), rows, + cudaMemcpyDeviceToDevice, stream); +} // WARP band: warp-per-(col,group) fused kernel. Each warp sorts+ranks one // pair entirely in registers (warp-shuffle bitonic, no smem, no __syncthreads). // Blocks pack 8 warps to amortise launch overhead. Fast route for @@ -192,6 +221,34 @@ static inline int checked_int_product(size_t a, size_t b, const char* context) { return (int)(a * b); } +// Precompute per-batch CSC column offsets rebased to each batch's ptr_start, +// laid out [n_batches][sub_batch_cols+1], and upload once. Returns the device +// buffer (allocated from `pool`). Avoids a per-batch H2D from a transient host +// buffer in the CSC host streaming impls. +template +static inline int* precompute_csc_batch_offsets(const IndptrT* h_indptr, + int n_cols, int sub_batch_cols, + int n_batches, + RmmScratchPool& pool, + const char* what) { + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int rem = n_cols - col_start; + int sb = (sub_batch_cols < rem) ? sub_batch_cols : rem; + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = checked_int_span( + (size_t)(h_indptr[col_start + i] - ptr_start), what); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + return d_all_offsets; +} + // Largest per-batch nonzero count we let a column batch reach. A batch is // sorted in a single CUB segmented call (int32 item count) and addressed with // int offsets, so it must stay below INT_MAX with margin. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 419860d6..608354a8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -42,9 +42,7 @@ static void ovo_streaming_csr_impl( n_sort_groups = (int)h_sort_group_ids.size(); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; @@ -265,9 +263,7 @@ static void ovo_streaming_csc_impl( n_sort_groups = (int)h_sort_group_ids.size(); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 1703c036..f0c120b5 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -43,9 +43,7 @@ static void ovo_streaming_csc_host_impl( n_sort_groups = (int)h_sort_group_ids.size(); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; @@ -88,8 +86,7 @@ static void ovo_streaming_csc_host_impl( sizeof(double) + cub_temp_bytes; size_t budget = rmm_available_device_bytes(0.8); - while (n_streams > 1 && (size_t)n_streams * per_stream > budget) - n_streams--; + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); } // pool first: streams drain before it frees their scratch (see guard doc). @@ -104,22 +101,9 @@ static void ovo_streaming_csc_host_impl( ScopedCudaStreams streams(n_streams, cudaStreamDefault); int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb = std::min(sub_batch_cols, n_cols - col_start); - IndptrT ptr_start = h_indptr[col_start]; - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) { - off[i] = - checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), - "OVO host CSC rebased column offsets"); - } - } - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + int* d_all_offsets = precompute_csc_batch_offsets( + h_indptr, n_cols, sub_batch_cols, n_batches, pool, + "OVO host CSC rebased column offsets"); // Row maps + group offsets + stats codes (uploaded once) int* d_ref_row_map = pool.alloc(n_rows); @@ -270,25 +254,17 @@ static void ovo_streaming_csc_host_impl( n_groups, compute_tie_corr, stream); // D2D: scatter sub-batch results into caller's GPU buffers - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_rank_sums + col, buf.d_rank_sums, n_groups, n_cols, + sb_cols, stream); if (compute_tie_corr) { - cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), - buf.d_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_tie_corr + col, buf.d_tie_corr, n_groups, n_cols, + sb_cols, stream); } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_sums + col, buf.d_group_sums, n_groups_stats, + n_cols, sb_cols, stream); if (compute_nnz) { - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_nnz + col, buf.d_group_nnz, n_groups_stats, + n_cols, sb_cols, stream); } col += sb_cols; @@ -751,17 +727,11 @@ static void ovo_streaming_csr_host_impl( pack_tpb_rank, n_ref, pack_rows, sb_cols, K, compute_tie_corr, stream); - cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, - n_cols * sizeof(double), buf.d_rank_sums, - sb_cols * sizeof(double), - sb_cols * sizeof(double), K, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_rank_sums + (size_t)pack.first * n_cols + col, + buf.d_rank_sums, K, n_cols, sb_cols, stream); if (compute_tie_corr) { - cudaMemcpy2DAsync( - d_tie_corr + (size_t)pack.first * n_cols + col, - n_cols * sizeof(double), buf.d_tie_corr, - sb_cols * sizeof(double), sb_cols * sizeof(double), K, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_tie_corr + (size_t)pack.first * n_cols + col, + buf.d_tie_corr, K, n_cols, sb_cols, stream); } col += sb_cols; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 65bbd69a..f66e4db8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -31,9 +31,7 @@ static void ovr_sparse_csc_host_streaming_impl( [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t max_nnz = 0; for (int col = 0; col < n_cols; col += sub_batch_cols) { @@ -103,22 +101,9 @@ static void ovr_sparse_csc_host_streaming_impl( // Pre-compute rebased per-batch offsets and upload once (avoids per-batch // H2D copy from a transient host buffer). int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb = std::min(sub_batch_cols, n_cols - col_start); - IndptrT ptr_start = h_indptr[col_start]; - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) { - off[i] = - checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), - "OVR host CSC rebased column offsets"); - } - } - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + int* d_all_offsets = precompute_csc_batch_offsets( + h_indptr, n_cols, sub_batch_cols, n_batches, pool, + "OVR host CSC rebased column offsets"); int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; @@ -189,24 +174,18 @@ static void ovr_sparse_csc_host_streaming_impl( n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_rank_sums + col, buf.d_rank_sums, n_groups, n_cols, + sb_cols, stream); if (compute_tie_corr) { cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_sums + col, buf.d_group_sums, n_groups, n_cols, + sb_cols, stream); if (compute_nnz) { - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_nnz + col, buf.d_group_nnz, n_groups, + n_cols, sb_cols, stream); } col += sb_cols; @@ -613,9 +592,8 @@ static void ovr_sparse_csr_host_streaming_impl( int n_streams = N_STREAMS; if (n_batches < n_streams) n_streams = n_batches; size_t stream_budget = budget - resident - (data_staged ? data_bytes : 0); - while (n_streams > 1 && - (size_t)n_streams * per_stream_bytes > stream_budget) - n_streams--; + n_streams = + clamp_streams_by_budget(n_streams, per_stream_bytes, stream_budget); ScopedCudaStreams streams(n_streams, cudaStreamDefault); @@ -730,24 +708,18 @@ static void ovr_sparse_csr_host_streaming_impl( buf.d_nz_scratch, n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); if (compute_tie_corr) { cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.sub_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_sums + col, buf.sub_group_sums, n_groups, + n_cols, sb_cols, stream); if (compute_nnz) { - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.sub_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(d_group_nnz + col, buf.sub_group_nnz, n_groups, + n_cols, sb_cols, stream); } col += sb_cols; @@ -787,9 +759,7 @@ static void ovr_sparse_csc_streaming_impl( [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t max_nnz = 0; for (int col = 0; col < n_cols; col += sub_batch_cols) { @@ -879,10 +849,8 @@ static void ovr_sparse_csc_streaming_impl( sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); if (compute_tie_corr) { cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, @@ -1002,8 +970,7 @@ static void ovr_sparse_csr_streaming_impl( } size_t budget = rmm_available_device_bytes(0.8); - while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) - n_streams--; + n_streams = clamp_streams_by_budget(n_streams, per_stream_bytes, budget); ScopedCudaStreams streams(n_streams, cudaStreamDefault); @@ -1082,10 +1049,8 @@ static void ovr_sparse_csr_streaming_impl( sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + scatter_cols_2d(rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); if (compute_tie_corr) { cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 14d3ef0c..b9e232fe 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -164,11 +164,7 @@ __global__ void rank_sums_sparse_ovr_kernel( double* warp_buf = smem + warp_buf_off; double v = wilcoxon_block_sum(local_tie_sum, warp_buf); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; - } + if (threadIdx.x == 0) tie_corr[col] = finalize_tie_corr(n_rows, v); } } From b555a28e9506cdce13e38994624417483d18d8a5 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 22 Jun 2026 18:27:27 +0200 Subject: [PATCH 25/36] update testing --- docs/release-notes/0.16.0.md | 1 + .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 40 ++ .../_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh | 3 +- .../tools/_rank_genes_groups/__init__.py | 11 +- .../tools/_rank_genes_groups/_utils.py | 23 +- .../tools/_rank_genes_groups/_wilcoxon.py | 12 +- .../_rank_genes_groups/_wilcoxon_binned.py | 33 +- tests/test_rank_genes_groups_ttest.py | 13 +- tests/test_rank_genes_groups_wilcoxon.py | 429 +++++++++++++++++- .../test_rank_genes_groups_wilcoxon_binned.py | 34 +- 10 files changed, 557 insertions(+), 42 deletions(-) diff --git a/docs/release-notes/0.16.0.md b/docs/release-notes/0.16.0.md index 1109d619..fdf715e4 100644 --- a/docs/release-notes/0.16.0.md +++ b/docs/release-notes/0.16.0.md @@ -3,6 +3,7 @@ ```{rubric} Features ``` * Reworked GPU {func}`~rapids_singlecell.tl.rank_genes_groups` Wilcoxon onto dedicated nanobind CUDA kernels {pr}`636` {smaller}`S Dicks` +* {func}`~rapids_singlecell.tl.rank_genes_groups` no longer truncates gene names longer than 50 characters in ``uns[...]['names']`` (the field is now ``object`` dtype, matching Scanpy) {pr}`636` {smaller}`S Dicks` * Add {class}`~rapids_singlecell.ptg.Mixscape` for GPU-accelerated Mixscape (`perturbation_signature`, `mixscape`, `mixscale`, `lda`) {pr}`688` {smaller}`S Dicks` ```{rubric} Performance diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 608354a8..7843579b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -81,6 +81,28 @@ static void ovo_streaming_csr_impl( cub_temp_bytes = cub_grp_bytes; } + // Clamp streams to the per-stream scratch budget (mirrors host OVO): the + // group dense slab scales with the cell count, so a fixed stream count + // would OOM at scale. The reference cache is allocated separately, so + // reserve its footprint first. + { + size_t per_stream = + sub_grp_items * sizeof(float) + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? cub_temp_bytes : 0) + + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium) + ? (size_t)sub_batch_cols * sizeof(double) + : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + size_t budget = rmm_available_device_bytes(0.8); + size_t ref_reserve = + 2 * (size_t)n_ref * (size_t)ref_cache_cols * sizeof(float); + budget = budget > ref_reserve ? budget - ref_reserve : 0; + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + ScopedCudaStreams streams(n_streams, cudaStreamDefault); ScopedCudaStream ref_stream(cudaStreamNonBlocking); @@ -285,6 +307,24 @@ static void ovo_streaming_csc_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } + // Clamp streams to the per-stream scratch budget (mirrors host OVO): the + // ref/group dense slabs scale with cell counts, so a fixed stream count + // would OOM at scale. + { + size_t per_stream = + 2 * sub_ref_items * sizeof(float) + + (run_huge ? 2 : 1) * sub_grp_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + cub_temp_bytes + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium) + ? (size_t)sub_batch_cols * sizeof(double) + : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + size_t budget = rmm_available_device_bytes(0.8); + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamDefault); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh index 09398f30..9b5e7377 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh @@ -52,7 +52,8 @@ __device__ __forceinline__ double ovr_walk_tie_runs( int total_tie = tie_global_end - tie_global_start; double avg_rank = - rank_offset + (double)(tie_global_start + tie_global_end + 1) / 2.0; + rank_offset + + ((double)tie_global_start + (double)tie_global_end + 1.0) / 2.0; for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 3ddaf067..c2c7f1b3 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -77,9 +77,10 @@ def rank_genes_groups( Rank genes for characterizing groups using GPU acceleration. Log1p/log-normalized data is expected for biologically meaningful log fold - changes. Sparse inputs with explicit negative values fall back to the dense - full-sort ranking path; dense inputs are ranked directly and support any - sign. + changes. In-memory sparse inputs with explicit negative values fall back to + the dense full-sort ranking path; dense inputs are ranked directly and + support any sign. (``wilcoxon_binned`` rejects negative Dask sparse input, + which it cannot bin correctly.) .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, @@ -238,6 +239,10 @@ def rank_genes_groups( msg = "return_u_values is only supported for method='wilcoxon'." raise ValueError(msg) + if chunk_size is not None and chunk_size <= 0: + msg = "chunk_size must be a positive integer." + raise ValueError(msg) + if key_added is None: key_added = "rank_genes_groups" diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 3814f662..796354d0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -18,15 +18,14 @@ def _sparse_has_negative(X) -> bool: - """Whether X is a sparse matrix holding an explicit negative value. - - The optimized sparse Wilcoxon paths rank explicit nonzeros and add the - implicit (structural) zeros analytically as a tie at the column minimum, - which is correct only when every stored value is nonnegative (counts / - log1p-normalized data). With a negative stored value the implicit zeros are - no longer the minimum, so that analytic ranking is wrong and the caller - must fall back to the dense full-sort path (valid for any sign). Dense - inputs and the t-test/logreg methods never need this. + """Whether an in-memory sparse ``X`` stores an explicit negative value. + + The fast sparse Wilcoxon paths add implicit (structural) zeros as a tie at + the column minimum, which is correct only for nonnegative stored values. A + negative breaks that, so the in-memory Wilcoxon paths fall back to the dense + full-sort path (valid for any sign). Dask arrays are not inspected here + (they are neither ``scipy`` nor ``cupy`` sparse); ``wilcoxon_binned`` guards + Dask sparse separately. Dense and t-test/logreg never need this. """ if sp.issparse(X) or cpsp.issparse(X): return X.nnz > 0 and float(X.data.min()) < 0 @@ -81,10 +80,8 @@ def _select_groups( if selected is None: selected = list(all_categories) - elif len(selected) > 1: - # Sort to match original category order (scanpy convention) - cat_order = {str(c): i for i, c in enumerate(all_categories)} - selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + # else: preserve the user-provided order. scanpy's select_groups does NOT + # re-sort to category order, so the output column order echoes `groups=`. if skip_empty_groups: counts = { diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index b1fb9168..8189a0c2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -38,11 +38,13 @@ def _maybe_preload_host_dense(rg: _RankGenes) -> None: return try: - _, total = cp.cuda.runtime.memGetInfo() + free, _ = cp.cuda.runtime.memGetInfo() except cp.cuda.runtime.CUDARuntimeError: return - if X.nbytes > total * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: + # Gate on *free* (the rmm_available_device_bytes convention), not total: + # under an RMM pool an array below a fraction of total can still exceed free. + if X.nbytes > free * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: return registered = False @@ -56,11 +58,11 @@ def _maybe_preload_host_dense(rg: _RankGenes) -> None: try: X_gpu = cp.asarray(X) cp.cuda.get_current_stream().synchronize() - except cp.cuda.memory.OutOfMemoryError: + # Under RMM an OOM surfaces as a bare MemoryError (std::bad_alloc), which + # also subsumes cupy's OutOfMemoryError subclass. + except (MemoryError, cp.cuda.runtime.CUDARuntimeError): cp.get_default_memory_pool().free_all_blocks() return - except cp.cuda.runtime.CUDARuntimeError: - return finally: if registered: try: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index f956f9db..97308d4f 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -103,8 +103,8 @@ def wilcoxon_binned( ``'log1p'`` for Dask arrays (to avoid a costly data scan). ``'log1p'`` uses a fixed [0, 15] range suitable for log1p-normalized data. - ``'auto'`` computes the actual (min, max) of the data. Use this - for nonnegative expression data outside the fixed log1p range. + ``'auto'`` computes the actual (min, max) of the data, spanning + negatives when present. Use it for data outside the fixed log1p range. """ if not rg.is_log1p: warnings.warn( @@ -125,6 +125,35 @@ def wilcoxon_binned( n_cells, n_genes = X.shape group_sizes = rg.group_sizes + # Dask sparse cannot bin negatives correctly: the sparse histogram puts + # implicit zeros in the lowest bin and _data_range floors the range at 0 + # for Dask sparse, so negatives would be silently mis-ranked. Refuse rather + # than return wrong numbers (in-memory sparse negatives use the dense + # fallback; see _sparse_has_negative). + if isinstance(X, DaskArray) and cpsp.issparse(X._meta): + + def _block_data_min(block): + if block.nnz > 0: + return block.data.min().reshape(1) + return cp.zeros(1, dtype=block.dtype) + + data_min = float( + X.map_blocks( + _block_data_min, + dtype=X.dtype, + drop_axis=1, + chunks=((1,) * len(X.chunks[0]),), + ) + .min() + .compute() + ) + if data_min < 0: + raise ValueError( + "wilcoxon_binned does not support negative values in Dask " + "sparse input; the binned approximation mis-ranks implicit " + "zeros. Densify the data or use a nonnegative representation." + ) + # group_codes: 0..n_groups-1 for selected cells, n_groups (sentinel) # for unselected. For vs-rest, unselected cells are binned into a # dummy group so they contribute to total counts for correct midranks. diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index 7f109e24..b076c0be 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -254,7 +254,8 @@ def test_rank_genes_groups_ttest_with_renamed_categories( @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): - """Test that group order doesn't affect results.""" + """Group order sets the output column order (matching scanpy); the per-group + statistics themselves are order-independent.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) adata.obs["blobs"] = adata.obs["blobs"].astype("category") @@ -271,9 +272,13 @@ def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): bdata, "blobs", method=method, groups=groups_reversed, reference=reference ) - expected_groups = {g for g in groups if g != reference} - assert set(adata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups - assert set(bdata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups + # Column order echoes the user-provided group order (reference excluded). + assert adata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups if g != reference + ) + assert bdata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups_reversed if g != reference + ) # Pick a group that's not the reference for comparison test_group = "3" if reference != "3" else "0" diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index e65a89db..e338be4e 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1023,7 +1023,8 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): - """Test that group order doesn't affect results.""" + """Group order sets the output column order (matching scanpy); the per-group + statistics themselves are order-independent.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) _make_nonnegative(adata) @@ -1040,9 +1041,13 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): bdata, "blobs", method="wilcoxon", groups=groups_reversed, reference=reference ) - expected_groups = {g for g in groups if g != reference} - assert set(adata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups - assert set(bdata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups + # Column order echoes the user-provided group order (reference excluded). + assert adata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups if g != reference + ) + assert bdata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups_reversed if g != reference + ) # Pick a group that's not the reference for comparison test_group = "3" if reference != "3" else "0" @@ -1475,3 +1480,419 @@ def test_rank_genes_groups_sparse_negative_values_fallback_ovo(fmt): rtol=1e-13, atol=1e-13, ) + + +@pytest.mark.parametrize("reference", ["rest", "2"]) +def test_wilcoxon_group_subset_column_order_matches_scanpy(reference): + """Output column order must echo the user's ``groups=`` list (scanpy parity), + not be re-sorted to category order.""" + np.random.seed(0) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) + bdata = adata.copy() + + # Deliberately out-of-category-order subset. + groups = ["3", "1"] if reference != "rest" else ["3", "1", "0"] + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + groups=groups, + reference=reference, + ) + sc.tl.rank_genes_groups( + bdata, + "blobs", + method="wilcoxon", + use_raw=False, + groups=groups, + reference=reference, + ) + assert ( + adata.uns["rank_genes_groups"]["names"].dtype.names + == bdata.uns["rank_genes_groups"]["names"].dtype.names + ) + + +def test_wilcoxon_host_sparse_negative_chunked_stats_match_scanpy(): + """Host scipy-sparse with negatives takes the dense fallback, whose group + means/vars/pts run through the group_chunk_stats kernel (multi-chunk). Those + means (-> logfoldchanges) and pts must match scanpy. + """ + rng = np.random.default_rng(0) + n_obs, n_vars = 200, 24 + X = (rng.random((n_obs, n_vars)) * 5.0).astype(np.float64) + X[X < 1.5] = 0.0 # structural zeros so pts < 1 + X[rng.random((n_obs, n_vars)) < 0.01] = -0.5 # a few negatives -> fallback + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_vars)]) + + gpu = sc.AnnData(X=sp.csr_matrix(X), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=X.copy(), obs=obs.copy(), var=var.copy()) + + rsc.tl.rank_genes_groups( + gpu, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + pts=True, + n_genes=n_vars, + chunk_size=8, # < n_vars -> multiple chunks + ) + sc.tl.rank_genes_groups( + cpu, "group", method="wilcoxon", use_raw=False, reference="rest", pts=True + ) + g = gpu.uns["rank_genes_groups"] + c = cpu.uns["rank_genes_groups"] + for group in g["names"].dtype.names: + g_lfc = dict( + zip(g["names"][group], np.asarray(g["logfoldchanges"][group], float)) + ) + c_lfc = dict( + zip(c["names"][group], np.asarray(c["logfoldchanges"][group], float)) + ) + for gene, val in g_lfc.items(): + np.testing.assert_allclose( + val, c_lfc[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + for frame in ("pts", "pts_rest"): + for col in c[frame].columns: + np.testing.assert_allclose( + g[frame].loc[c[frame].index, col].values, + c[frame][col].values, + rtol=1e-12, + atol=1e-13, + ) + + +def test_wilcoxon_fdr_ties_nan_match_scanpy(): + """BH FDR must match scanpy on heavily-tied / constant / all-zero genes, + locking in that the GPU argsort tie-break is inert for adjusted p-values. + Integer data is float32-exact, so ranking is bit-identical to scanpy and the + comparison isolates the FDR step. + """ + rng = np.random.default_rng(1) + n_obs, n_vars = 240, 30 + X = rng.integers(0, 3, size=(n_obs, n_vars)).astype(np.float64) # heavy ties + X[:, 0] = 1.0 # constant gene -> identical p across groups + X[:, 1] = 0.0 # all-zero gene + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_vars)]) + + gpu = sc.AnnData(X=cp.asarray(X), obs=obs.copy(), var=var.copy()) # GPU FDR path + cpu = sc.AnnData(X=X.copy(), obs=obs.copy(), var=var.copy()) + + rsc.tl.rank_genes_groups( + gpu, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + sc.tl.rank_genes_groups( + cpu, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + g = gpu.uns["rank_genes_groups"] + c = cpu.uns["rank_genes_groups"] + for group in g["names"].dtype.names: + g_adj = dict(zip(g["names"][group], np.asarray(g["pvals_adj"][group], float))) + c_adj = dict(zip(c["names"][group], np.asarray(c["pvals_adj"][group], float))) + for gene, val in g_adj.items(): + np.testing.assert_allclose( + val, c_adj[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +def _promote_host_index_dtypes(X, *, indptr64, indices64): + """Copy a host scipy CSR/CSC matrix with promoted index-array dtypes. + + scipy couples indptr/indices to one dtype via get_index_dtype, so the + decoupled (i64 indptr / i32 indices) combination only arises by explicit + promotion -- which is exactly what drives the templated host kernels. + """ + X = X.copy() + if indptr64: + X.indptr = X.indptr.astype(np.int64) + if indices64: + X.indices = X.indices.astype(np.int64) + return X + + +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR vs OVO host paths +@pytest.mark.parametrize( + ("layout", "data_dtype", "indices64"), + [ + ("csr", np.float32, False), # *_i64 + ("csr", np.float32, True), # *_i64_idx64 + ("csr", np.float64, False), # *_f64_i64 + ("csr", np.float64, True), # *_f64_i64_idx64 + ("csc", np.float32, False), # *_i64 (CSC has no idx64 template) + ("csc", np.float64, False), # *_f64_i64 + ], +) +def test_host_sparse_int64_templates_match_int32( + reference, layout, data_dtype, indices64 +): + """Exercise the host-sparse int64-indptr / int64-indices kernel templates + (the 12 ``*_i64`` / ``*_idx64`` / ``*_f64_i64`` host bindings the suite + otherwise never reaches). These differ from the validated int32 host path + only in index dtype, so they must be bit-identical to it. Real int64 indices + only occur at nnz > 2^31 (unallocatable in CI), so we promote a small + matrix's index arrays explicitly and keep it host-resident (scipy sparse + + method='wilcoxon' is not moved to GPU).""" + rng = np.random.default_rng(0) + dense = (rng.random((150, 8)) * 4.0).astype(np.float64) + dense[dense < 1.5] = 0.0 # nonnegative + structural zeros -> sparse fast path + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(8)]) + + maker = sp.csr_matrix if layout == "csr" else sp.csc_matrix + base = maker(dense.astype(data_dtype)) + + a32 = sc.AnnData(X=base.copy(), obs=obs.copy(), var=var.copy()) + a64 = sc.AnnData( + X=_promote_host_index_dtypes(base, indptr64=True, indices64=indices64), + obs=obs.copy(), + var=var.copy(), + ) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(a32, "group", **kw) + rsc.tl.rank_genes_groups(a64, "group", **kw) + + r32, r64 = a32.uns["rank_genes_groups"], a64.uns["rank_genes_groups"] + assert r64["names"].dtype.names == r32["names"].dtype.names + for fld in ("scores", "pvals", "pvals_adj", "logfoldchanges"): + for grp in r32[fld].dtype.names: + np.testing.assert_array_equal( + np.asarray(r64[fld][grp]), np.asarray(r32[fld][grp]) + ) + + +def _anndata_with_group_sizes(sizes, *, n_genes=6, seed=0): + """Dense AnnData whose per-group cell counts are exactly ``sizes``. + + The OVO tier dispatch picks the rank kernel by *test-group* size + (WARP<=32, SMALL 33-64, MEDIUM 65-512, LARGE 513-2500, HUGE>2500), so + engineered group sizes drive specific bands. Integer data is float32-exact, + so ranking is bit-identical to scanpy. + """ + rng = np.random.default_rng(seed) + labels = [] + for name, n in sizes.items(): + labels += [name] * n + X = rng.integers(0, 6, size=(len(labels), n_genes)).astype(np.float64) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + return sc.AnnData(X=X, obs=obs, var=var) + + +def _assert_ovo_matches_scanpy(adata, reference): + bdata = adata.copy() + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(bdata, "group", **kw) + g, c = adata.uns["rank_genes_groups"], bdata.uns["rank_genes_groups"] + for fld in ("scores", "pvals", "pvals_adj"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +def test_ovo_tier_bands_warp_small_medium_large_match_scanpy(): + """OVO dense-tiered path must hit the WARP/SMALL/MEDIUM/LARGE rank kernels + (test-group sizes 20/50/300/1000, all <= 2500) and match scanpy.""" + adata = _anndata_with_group_sizes( + {"ref": 40, "warp": 20, "small": 50, "medium": 300, "large": 1000}, seed=1 + ) + _assert_ovo_matches_scanpy(adata, reference="ref") + + +def test_ovo_tier_band_huge_match_scanpy(): + """OVO dense-tiered path must hit the HUGE band (CUB segmented sort, a + test-group > 2500 cells) and match scanpy.""" + adata = _anndata_with_group_sizes({"ref": 40, "huge": 3000}, seed=2) + _assert_ovo_matches_scanpy(adata, reference="ref") + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn +def test_ovr_dense_gmem_branch_matches_scipy(): + """The DENSE OVR global-memory accumulator (use_gmem) engages only when the + per-block group accumulators exceed the 48 KB MaxSharedMemoryPerBlock limit + -- n_groups > 6112 (= 49152 / 8 - 32). No other test reaches it (the + >3056-group gmem tests only flip the *sparse* accumulator). n_groups=6200 + deterministically routes through dense gmem; a scanpy oracle here costs ~30s + (its per-group Python loop), so we validate a sample of groups against scipy + mannwhitneyu with rsc's exact settings (tie-corrected asymptotic, no + continuity).""" + from scipy.stats import mannwhitneyu + + n_groups, n_genes = 6200, 4 # > 6112 -> dense gmem accumulator + rng = np.random.default_rng(3) + labels = np.repeat(np.arange(n_groups), 2) # 2 cells per group + X = rng.integers(0, 6, size=(labels.size, n_genes)).astype(np.float64) + obs = pd.DataFrame({"group": pd.Categorical([str(x) for x in labels])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + res = adata.uns["rank_genes_groups"] + for grp in ("0", "1", "250", "1000", "3057", "6112", "6199"): + gp = dict(zip(res["names"][grp], np.asarray(res["pvals"][grp], float))) + mask = labels == int(grp) + for gi, gene in enumerate(var.index): + _, p = mannwhitneyu( + X[mask, gi], + X[~mask, gi], + use_continuity=False, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gp[gene], + p, + rtol=1e-10, + atol=1e-12, + equal_nan=True, + err_msg=f"group {grp} gene {gene}", + ) + + +def test_skip_empty_groups_vs_rest_drops_singleton(): + """skip_empty_groups=True with reference='rest' silently drops <2-cell + groups (covers the reference=='rest' branch of _select_groups, which the + existing reference='ref' skip tests miss).""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=4) + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, skip_empty_groups=True + ) + names = set(adata.uns["rank_genes_groups"]["names"].dtype.names) + assert names == {"a", "b"} # singleton "c" dropped, no error + + +def test_skip_empty_groups_reference_too_small_raises(): + """skip_empty_groups=True with a <2-cell reference raises a clear error.""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=4) + with pytest.raises(ValueError, match="reference = c has fewer than two samples"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="c", + skip_empty_groups=True, + ) + + +def test_skip_empty_groups_none_remain_raises(): + """skip_empty_groups=True raises when no group has >=2 cells (vs-rest).""" + adata = _anndata_with_group_sizes({"a": 1, "b": 1, "c": 1}, seed=4) + with pytest.raises(ValueError, match="No groups with at least two samples remain"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, skip_empty_groups=True + ) + + +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +def test_ovr_tie_correct_false_tie_heavy_matches_scanpy(fmt): + """OVR (reference='rest') tie_correct=False on TIE-HEAVY data must match + scanpy across every storage format. The pre-existing tie_correct=False oracle + uses tie-free blobs (tie_corr ~= 1), so a wrong uncorrected variance would + pass; integer data with many ties stresses the omitted tie term on each path + (dense, host-sparse CSR/CSC, device-sparse CSR/CSC).""" + rng = np.random.default_rng(7) + n_obs, n_genes = 180, 8 + dense = rng.integers(0, 5, size=(n_obs, n_genes)).astype(np.float64) # ties + dense[dense < 1.0] = 0.0 # nonnegative + structural zeros -> sparse fast path + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": False, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR and OVO epilogues +def test_use_continuity_matches_scipy(fmt, reference): + """use_continuity=True is validated only on dense-OVO elsewhere. Check the + continuity epilogue composes correctly with each path's rank_sums (OVR + OVO, + every format) vs scipy.mannwhitneyu(use_continuity=True, asymptotic). + + Groups OVERLAP (no separation) on purpose: that keeps |R-E[R]| moderate so + the 0.5 continuity term MATERIALLY changes p -- a missing continuity + correction would then fail the scipy oracle (non-vacuous). Because U and E[U] + are multiples of 0.5, rsc's clamp (max(|d|-0.5,0)) and scipy's shift agree + exactly. tie_correct=True matches scipy's always-on asymptotic tie term.""" + from scipy.stats import mannwhitneyu + + rng = np.random.default_rng(8) + n_obs, n_genes = 150, 6 + # Overlapping groups (same distribution) -> moderate |R-E[R]| -> continuity + # is material. Integer values give ties (exercises the tie term too). + dense = rng.integers(0, 4, size=(n_obs, n_genes)).astype(np.float64) + labels = np.array([str(i % 3) for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + rsc.tl.rank_genes_groups( + gpu, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + tie_correct=True, + use_continuity=True, + n_genes=n_genes, + ) + res = gpu.uns["rank_genes_groups"] + for grp in res["names"].dtype.names: + gm = dict(zip(res["names"][grp], np.asarray(res["pvals"][grp], float))) + mask_g = labels == grp + mask_r = (labels != grp) if reference == "rest" else (labels == reference) + for gi, gene in enumerate(var.index): + _, p = mannwhitneyu( + dense[mask_g, gi], + dense[mask_r, gi], + use_continuity=True, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gm[gene], p, rtol=1e-10, atol=1e-12, equal_nan=True + ) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index f0e6848d..01bd47ac 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -429,21 +429,35 @@ def test_sparse_with_actual_zeros(self, adata_blobs): assert np.all(pvals <= 1) def test_sparse_negative_values_fallback(self, adata_blobs): - """Sparse input with negatives must not use the sparse histogram (which - assigns implicit zeros to bin 0, valid only for nonnegative data); it - falls back to the dense histogram, so the result matches the dense run. + """Sparse input with negatives must densify: the sparse histogram puts + implicit zeros in bin 0 (valid only for nonnegative data). A *correct* + fallback (densify) matches the dense run; a removed fallback would bin + the implicit zeros below stored negatives and diverge -- so this + assertion fails without the fallback. + + Sensitivity hinges on columns holding BOTH structural zeros AND a value + below them (a negative). Where the zeros are the column minimum, moving + them to bin 0 leaves their rank order unchanged and the binned z is + invariant (which is why a naive sparse-vs-dense check is vacuous). """ import cupy as cp import cupyx.scipy.sparse as cpsp - adata = adata_blobs.copy() - rsc.get.anndata_to_GPU(adata) - dense = cp.asarray(adata.X, dtype=cp.float64) - dense[:, 0] = -1.0 - - sparse_adata = adata.copy() + rng = np.random.default_rng(0) + n_obs, n_vars = adata_blobs.shape + base = (rng.random((n_obs, n_vars)) * 5.0).astype(np.float64) + base[base < 2.0] = 0.0 # real structural zeros (~40%) + # Negatives in zero-bearing cells: columns then hold structural zeros + # AND values below them, the case the fallback must rank correctly. + neg = (base == 0.0) & (rng.random(base.shape) < 0.05) + base[neg] = -0.5 + base[0, 1] = 10.0 # keep a positive max so sparse/dense ranges agree + assert (base == 0).any() and (base < 0).any() and (base > 0).any() + + dense = cp.asarray(base) + sparse_adata = adata_blobs.copy() sparse_adata.X = cpsp.csr_matrix(dense) - dense_adata = adata.copy() + dense_adata = adata_blobs.copy() dense_adata.X = dense rsc.tl.rank_genes_groups( From bc8bc9e6c8fb407b540a82090b8e63f5bd414476 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 22 Jun 2026 18:45:42 +0200 Subject: [PATCH 26/36] add more tests --- tests/dask/test_dask_rank_wilcoxon_binned.py | 69 ++++ tests/test_rank_genes_groups_wilcoxon.py | 328 ++++++++++++++++++ .../test_rank_genes_groups_wilcoxon_binned.py | 126 +++++++ 3 files changed, 523 insertions(+) diff --git a/tests/dask/test_dask_rank_wilcoxon_binned.py b/tests/dask/test_dask_rank_wilcoxon_binned.py index 5f49fb6b..dc23b317 100644 --- a/tests/dask/test_dask_rank_wilcoxon_binned.py +++ b/tests/dask/test_dask_rank_wilcoxon_binned.py @@ -137,3 +137,72 @@ def test_wilcoxon_binned_dask_reference(client, data_kind): ) _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +@pytest.mark.parametrize("data_kind", ["sparse", "dense"]) +def test_wilcoxon_binned_dask_auto_range(client, data_kind): + """bin_range='auto' exercises the Dask _data_range branches (per-block + min/max via map_blocks), which the bin_range='log1p' tests never reach.""" + adata, dask_data, groupby = _setup_data(data_kind) + + for ad_ in (adata, dask_data): + rsc.tl.rank_genes_groups( + ad_, + groupby=groupby, + method="wilcoxon_binned", + n_bins=200, + bin_range="auto", + use_raw=False, + ) + + _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +@pytest.mark.parametrize("data_kind", ["sparse", "dense"]) +def test_wilcoxon_binned_dask_reference_subset(client, data_kind): + """Dask + reference + groups-subset together (the has_unselected Dask + branch where unselected cells coexist with a reference group).""" + adata, dask_data, groupby = _setup_data(data_kind) + cats = [str(c) for c in adata.obs[groupby].cat.categories] + groups = cats[1:4] # subset that excludes the reference -> unselected cells exist + reference = cats[0] + + for ad_ in (adata, dask_data): + rsc.tl.rank_genes_groups( + ad_, + groupby=groupby, + method="wilcoxon_binned", + groups=groups, + reference=reference, + n_bins=1000, + bin_range="log1p", + use_raw=False, + ) + + _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +def test_wilcoxon_binned_dask_negative_sparse_raises(client): + """Dask sparse input with a stored negative is refused (the binned histogram + cannot place implicit zeros correctly for signed data).""" + import anndata as ad_mod + import pandas as pd + import scipy.sparse as sp + + rng = np.random.default_rng(0) + X = np.abs(rng.standard_normal((60, 8))).astype(np.float32) + X[X < 0.5] = 0.0 + X[0, 0] = -1.0 # one stored negative + obs = pd.DataFrame({"g": pd.Categorical([f"{i % 3}" for i in range(60)])}) + var = pd.DataFrame(index=[f"v{i}" for i in range(8)]) + adata = ad_mod.AnnData(X=sp.csr_matrix(X), obs=obs, var=var) + adata.X = as_sparse_cupy_dask_array(adata.X).persist() + + with pytest.raises(ValueError, match="negative values in Dask sparse"): + rsc.tl.rank_genes_groups( + adata, + groupby="g", + method="wilcoxon_binned", + bin_range="auto", + use_raw=False, + ) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index e338be4e..9f2f7a66 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1896,3 +1896,331 @@ def test_use_continuity_matches_scipy(fmt, reference): np.testing.assert_allclose( gm[gene], p, rtol=1e-10, atol=1e-12, equal_nan=True ) + + +# --------------------------------------------------------------------------- +# Entry-point / init validation (rank_genes_groups + _RankGenes + _select_groups) +# --------------------------------------------------------------------------- + + +def test_rank_genes_groups_default_method_is_ttest(): + """Omitting method= defaults to t-test (rank_genes_groups path).""" + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + rsc.tl.rank_genes_groups(adata, "group", use_raw=False) + assert adata.uns["rank_genes_groups"]["params"]["method"] == "t-test" + + +@pytest.mark.parametrize( + ("override", "exc", "match"), + [ + ({"method": "nope"}, ValueError, "method must be one of"), + ({"corr_method": "foo"}, ValueError, "corr_method must be either"), + ({"chunk_size": 0}, ValueError, "chunk_size must be a positive integer"), + ({"chunk_size": -4}, ValueError, "chunk_size must be a positive integer"), + ({"groups": "0"}, ValueError, "Specify a sequence of groups"), + ({"reference": "ZZ"}, ValueError, "needs to be one of groupby"), + ], +) +def test_rank_genes_groups_invalid_args_raise(override, exc, match): + """Public-API argument validation raises (covers __init__/_core guards).""" + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + kwargs = {"method": "wilcoxon", "use_raw": False, **override} + with pytest.raises(exc, match=match): + rsc.tl.rank_genes_groups(adata, "group", **kwargs) + + +def test_rank_genes_groups_mask_var_missing_key_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(KeyError, match="not found in adata.var"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, mask_var="nope" + ) + + +def test_rank_genes_groups_mask_var_wrong_shape_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(ValueError, match="mask_var has wrong shape"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + mask_var=np.ones(adata.n_vars + 3, dtype=bool), + ) + + +def test_rank_genes_groups_layer_and_use_raw_conflict_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + adata.layers["L"] = adata.X.copy() + with pytest.raises(ValueError, match="Cannot specify .layer. and have"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", layer="L", use_raw=True + ) + + +def test_rank_genes_groups_use_raw_without_raw_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(ValueError, match="is empty"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=True) + + +def test_singleton_group_without_skip_raises(): + """Non-skip path: a <2-cell group raises in _select_groups (line 131-135).""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=5) + with pytest.raises(ValueError, match="fewer than two samples"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +@pytest.mark.parametrize("use_raw", [None, True]) +def test_rank_genes_groups_reads_raw_matches_scanpy(use_raw): + """use_raw=None (raw present) and use_raw=True both read adata.raw. X is + overwritten with rank-scrambling noise so a path that wrongly read .X would + diverge from scanpy (non-vacuous).""" + adata = _anndata_with_group_sizes({"0": 30, "1": 30, "2": 30}, seed=6) + adata.raw = adata.copy() # raw holds the real signal + rng = np.random.default_rng(99) + adata.X = rng.integers(0, 6, size=adata.shape).astype(np.float64) # noise in .X + bdata = adata.copy() + kw = {"method": "wilcoxon", "use_raw": use_raw, "tie_correct": True} + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(bdata, "group", **kw) + g, c = adata.uns["rank_genes_groups"], bdata.uns["rank_genes_groups"] + for grp in g["scores"].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g["scores"][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c["scores"][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR (_core) + OVO (_wilcoxon) +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr"]) +def test_log1p_base_logfoldchanges_match_scanpy(reference, fmt): + """A non-default log1p base changes expm1 in the logfoldchange computation + (_core.py:115 + the OVO host-sparse fast path _wilcoxon.py:232-234).""" + rng = np.random.default_rng(7) + dense = rng.integers(1, 6, size=(120, 6)).astype(np.float64) # nonneg, finite lfc + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(120)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + gpu.uns["log1p"] = {"base": 2.0} + cpu.uns["log1p"] = {"base": 2.0} + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for grp in g["logfoldchanges"].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g["logfoldchanges"][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c["logfoldchanges"][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-6, atol=1e-6, equal_nan=True + ) + + +# --------------------------------------------------------------------------- +# OVO / OVR parity & dispatch gaps +# --------------------------------------------------------------------------- + + +def test_ovo_dense_fallback_pts_match_scanpy(): + """OVO sparse-negative dense fallback computes pts via _fill_ovo_chunk_stats + (ref + group branches). Validate pts vs scanpy on the dense equivalent.""" + rng = np.random.default_rng(11) + dense = (rng.random((120, 8)) * 5.0).astype(np.float64) + dense[dense < 1.5] = 0.0 + dense[rng.random(dense.shape) < 0.01] = -0.5 # negatives -> dense fallback + obs = pd.DataFrame( + {"group": pd.Categorical(["a" if i % 2 else "b" for i in range(120)])} + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(8)]) + gpu = sc.AnnData(X=sp.csr_matrix(dense), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "b", + "pts": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for col in c["pts"].columns: + np.testing.assert_allclose( + g["pts"].loc[c["pts"].index, col].values, + c["pts"][col].values, + rtol=1e-12, + atol=1e-13, + ) + + +@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_csr"]) # CPU + GPU FDR epilogues +def test_bonferroni_matches_scanpy(fmt): + """Bonferroni correction (CPU _core.py:584 via dense, GPU :630-631 via the + cupy OVO result path) must match scanpy, not just be <=1 (the prior tests + only asserted the tautological clamp).""" + rng = np.random.default_rng(12) + dense = rng.integers(0, 5, size=(150, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "1", + "corr_method": "bonferroni", + "tie_correct": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals", "pvals_adj"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +def test_ovr_all_empty_csc_totals_runs(): + """All-zero host CSC + a groups= subset (leaves an unselected category) + + reference='rest' + pts=True exercises the empty-column totals branch.""" + dense = np.zeros((20, 5), dtype=np.float64) + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(20)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(5)]) + adata = sc.AnnData(X=sp.csc_matrix(dense), obs=obs, var=var) + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + groups=["0", "1"], + reference="rest", + pts=True, + ) + res = adata.uns["rank_genes_groups"] + for grp in res["scores"].dtype.names: + assert np.all(np.isfinite(np.asarray(res["scores"][grp], float))) + assert "pts_rest" in res + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_ovr_fully_dense_column_match_scanpy(fmt): + """A column with no structural zeros (nnz==n_rows) hits the total_zero==0 + branch of the sparse OVR accumulate kernel. Validate vs scanpy.""" + rng = np.random.default_rng(13) + dense = rng.integers(0, 5, size=(90, 4)).astype(np.float64) + dense[dense < 1.0] = 0.0 + dense[:, 0] = rng.integers(1, 6, size=90) # column 0 strictly positive -> no zeros + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(90)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(4)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +@pytest.mark.parametrize("fmt", ["cupy_csr", "cupy_csc"]) +def test_ovr_device_sparse_subset_match_scanpy(fmt): + """Device-sparse OVR with a groups= subset exercises the sentinel-group skip + in the device sparse kernels. Validate vs scanpy on the dense copy.""" + rng = np.random.default_rng(14) + dense = rng.integers(0, 6, size=(160, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 4}" for i in range(160)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_host_csc_int64_indices_cast_matches_int32(reference): + """Host CSC has no *_idx64 template, so int64 indices are cast to int32 + (_wilcoxon.py:355->357). Result must be bit-identical to the int32 input.""" + rng = np.random.default_rng(15) + dense = rng.integers(0, 5, size=(120, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(120)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + base = sp.csc_matrix(dense) + a32 = sc.AnnData(X=base.copy(), obs=obs.copy(), var=var.copy()) + m64 = base.copy() + m64.indices = m64.indices.astype(np.int64) # keep indptr int32 + a64 = sc.AnnData(X=m64, obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(a32, "group", **kw) + rsc.tl.rank_genes_groups(a64, "group", **kw) + r32, r64 = a32.uns["rank_genes_groups"], a64.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in r32[fld].dtype.names: + np.testing.assert_array_equal( + np.asarray(r64[fld][grp]), np.asarray(r32[fld][grp]) + ) + + +def test_device_sparse_float16_raises(): + """A cupy sparse matrix with float16 data raises a clear TypeError (the + device counterpart of the host float16 guard).""" + rng = np.random.default_rng(16) + dense = np.abs(rng.standard_normal((60, 4))).astype(np.float32) + dense[dense < 0.5] = 0.0 + mat = cpsp.csr_matrix(cp.asarray(dense)) + mat.data = mat.data.astype(cp.float16) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(60)])}), + var=pd.DataFrame(index=[f"g{i}" for i in range(4)]), + ) + with pytest.raises(TypeError, match="float32"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index 01bd47ac..fc05af15 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -566,3 +566,129 @@ def test_top_genes_match_scipy(adata_blobs): scipy_top = set(adata_blobs.var_names[np.argsort(pvals)[:n_top]]) overlap = len(binned_top & scipy_top) assert overlap >= n_top - 1, f"Group {group}: {overlap}/{n_top} overlap" + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_binned_bin_exact_matches_scipy(reference): + """wilcoxon_binned otherwise has NO external numeric oracle. With integer + data and n_bins >> value-range, each value gets its own bin -> binned ranks + == exact ranks -> binned pvals must match scipy.mannwhitneyu exactly + (tie_correct=True matches scipy's always-on asymptotic tie term). Covers + vs-rest and vs-ref, tie_correct and use_continuity, with non-vacuity + self-guards (each flag must materially change the result).""" + import pandas as pd + from scipy.stats import mannwhitneyu + + rng = np.random.default_rng(20) + n_obs, n_genes = 150, 6 + X = rng.integers(0, 5, size=(n_obs, n_genes)).astype(np.float32) # bin-exact + labels = np.array([str(i % 3) for i in range(n_obs)]) + genes = [f"v{i}" for i in range(n_genes)] + + def run(tie_correct, use_continuity): + a = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"g": pd.Categorical(labels)}), + var=pd.DataFrame(index=genes), + ) + a.uns["log1p"] = {"base": None} # silence the log-norm warning + rsc.get.anndata_to_GPU(a) + rsc.tl.rank_genes_groups( + a, + "g", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + tie_correct=tie_correct, + use_continuity=use_continuity, + n_bins=1000, + bin_range="auto", + n_genes=n_genes, + ) + r = a.uns["rank_genes_groups"] + return { + grp: dict(zip(r["names"][grp], np.asarray(r["pvals"][grp], float))) + for grp in r["names"].dtype.names + } + + # correctness vs scipy (tie_correct=True; both continuity settings) + for use_continuity in (False, True): + pv = run(True, use_continuity) + for grp, gm in pv.items(): + mask_g = labels == grp + mask_r = (labels != grp) if reference == "rest" else (labels == reference) + for gi, v in enumerate(genes): + _, p = mannwhitneyu( + X[mask_g, gi], + X[mask_r, gi], + use_continuity=use_continuity, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gm[v], p, rtol=1e-6, atol=1e-6, equal_nan=True + ) + + # non-vacuity self-guards: each flag must materially change the result + def differs(a, b): + return any( + not np.isclose(a[g][v], b[g][v], rtol=1e-9, atol=1e-12) + for g in a + for v in a[g] + ) + + assert differs(run(True, True), run(True, False)), "use_continuity inert (vacuous)" + assert differs(run(True, False), run(False, False)), "tie_correct inert (vacuous)" + + +def test_binned_all_zero_sparse_finite(adata_blobs): + """All-zero in-memory sparse input (nnz==0 _data_range guard): no crash, all + pvals finite and 1.0 (every value in one bin -> z=0).""" + import cupy as cp + import cupyx.scipy.sparse as cpsp + + adata = adata_blobs.copy() + adata.X = cpsp.csr_matrix(cp.zeros(adata.shape, dtype=cp.float32)) + rsc.tl.rank_genes_groups(adata, "blobs", method="wilcoxon_binned", use_raw=False) + res = adata.uns["rank_genes_groups"] + for grp in res["pvals"].dtype.names: + p = np.asarray(res["pvals"][grp], dtype=float) + assert np.all(np.isfinite(p)) and np.allclose(p, 1.0) + + +def test_binned_log1p_invalid_for_negative_sparse_coerces_to_auto(adata_blobs): + """Sparse input with negatives + bin_range='log1p' warns and coerces to + 'auto' (the fixed [0,15] range would clamp negatives). Result must equal the + explicit 'auto' run (non-vacuous: no coercion -> mis-binned negatives differ).""" + import cupy as cp + import cupyx.scipy.sparse as cpsp + + rng = np.random.default_rng(21) + n_obs, n_vars = adata_blobs.shape + base = (rng.random((n_obs, n_vars)) * 4.0).astype(np.float64) + base[base < 1.5] = 0.0 + neg = (base == 0.0) & (rng.random(base.shape) < 0.05) + base[neg] = -0.5 + base[0, 1] = 10.0 # positive max so sparse/dense ranges align + dense = cp.asarray(base) + + sp_ad = adata_blobs.copy() + sp_ad.X = cpsp.csr_matrix(dense) + auto_ad = adata_blobs.copy() + auto_ad.X = cpsp.csr_matrix(dense) + with pytest.warns(RuntimeWarning, match="bin_range='log1p' is invalid"): + rsc.tl.rank_genes_groups( + sp_ad, "blobs", method="wilcoxon_binned", use_raw=False, bin_range="log1p" + ) + rsc.tl.rank_genes_groups( + auto_ad, "blobs", method="wilcoxon_binned", use_raw=False, bin_range="auto" + ) + sp_s = sp_ad.uns["rank_genes_groups"]["scores"] + au_s = auto_ad.uns["rank_genes_groups"]["scores"] + for grp in sp_s.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_s[grp], dtype=float), + np.asarray(au_s[grp], dtype=float), + rtol=1e-13, + atol=1e-13, + ) From 7ca959d08819e02019f29b2d30322d28f1f2aa09 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 23 Jun 2026 17:29:06 +0200 Subject: [PATCH 27/36] update kernels and layout --- .../_cuda/rank_genes/csr_tile_to_dense.cuh | 37 --- .../_cuda/rank_genes/rank_stats.cu | 53 +++- .../_cuda/sparse_extract/sparse_extract.cuh | 193 ++++++++++++++ .../_cuda/wilcoxon/wilcoxon.cu | 241 ++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 64 +---- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 28 +- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 99 ++++--- .../tools/_rank_genes_groups/__init__.py | 4 - .../tools/_rank_genes_groups/_core.py | 5 +- .../tools/_rank_genes_groups/_utils.py | 73 +++++- .../tools/_rank_genes_groups/_wilcoxon.py | 135 +++++++--- tests/test_rank_genes_groups_wilcoxon.py | 134 +++++++++- 12 files changed, 834 insertions(+), 232 deletions(-) delete mode 100644 src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh create mode 100644 src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh diff --git a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh b/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh deleted file mode 100644 index b6e881c4..00000000 --- a/src/rapids_singlecell/_cuda/rank_genes/csr_tile_to_dense.cuh +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include - -// Single-pass CSR-slice + densify: scatter column window [col_lb, col_ub) into -// a dense (n_cells, col_ub-col_lb) F-order double buffer, skipping the CSR -> -// CSC rebuild a `X[:, lb:ub].tocsc()` densify would do. -// -// `out` must be pre-zeroed; the atomicAdd also sums duplicate column indices -// (like scipy's sum_duplicates) -- bit-identical to dense materialization for -// canonical CSR. Output is always double; input dtype is templated. - -template -__global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, - const IndexT* __restrict__ indices, - const TData* __restrict__ data, - double* __restrict__ out, int col_lb, - int col_ub, int n_cells) { - const long long row = - static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - if (row >= n_cells) { - return; - } - const long long row_start = static_cast(indptr[row]); - const long long row_end = static_cast(indptr[row + 1]); - // Keep column ids in IndexT: narrowing a 64-bit IndexT to int would - // truncate large column ids and misplace writes. - const IndexT lb = static_cast(col_lb); - const IndexT ub = static_cast(col_ub); - for (long long k = row_start; k < row_end; ++k) { - const IndexT col = indices[k]; - if (col >= lb && col < ub) { - atomicAdd(&out[static_cast(col - lb) * n_cells + row], - static_cast(data[k])); - } - } -} diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu index db5f9ee8..6c07afcb 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -1,7 +1,7 @@ #include #include "../nb_types.h" -#include "csr_tile_to_dense.cuh" +#include "../sparse_extract/sparse_extract.cuh" using namespace nb::literals; @@ -103,6 +103,48 @@ static void def_csr_tile_to_dense(nb::module_& m) { "col_ub"_a, "stream"_a = 0); } +// CSC -> dense F-order (double) window densify, fused pass (column-major). +template +static void def_csc_tile_to_dense(nb::module_& m) { + m.def( + "csc_tile_to_dense", + [](gpu_array_c indptr, + gpu_array_c indices, + gpu_array_c data, + gpu_array_f out, int col_lb, int col_ub, + std::uintptr_t stream) { + const int n_cells = static_cast(out.shape(0)); + const int n_win = col_ub - col_lb; + if (n_cells <= 0 || n_win <= 0) { + return; + } + if (col_lb < 0) { + throw std::invalid_argument( + "csc_tile_to_dense: col_lb must be non-negative"); + } + if (indices.shape(0) != data.shape(0)) { + throw std::invalid_argument( + "csc_tile_to_dense: indices and data must have equal " + "length"); + } + if (out.ndim() != 2 || + static_cast(out.shape(1)) < n_win) { + throw std::invalid_argument( + "csc_tile_to_dense: out must be a (n_cells, >= col_ub - " + "col_lb) array"); + } + constexpr int CSC_TILE_BLOCK = 128; + csc_tile_to_dense_kernel + <<(n_win), CSC_TILE_BLOCK, 0, + (cudaStream_t)stream>>>(indptr.data(), indices.data(), + data.data(), out.data(), col_lb, + col_ub, n_cells); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + }, + "indptr"_a, "indices"_a, "data"_a, "out"_a, nb::kw_only(), "col_lb"_a, + "col_ub"_a, "stream"_a = 0); +} + template void register_bindings(nb::module_& m) { def_csr_tile_to_dense(m); @@ -114,6 +156,15 @@ void register_bindings(nb::module_& m) { def_csr_tile_to_dense(m); def_csr_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + m.def( "fdr_bh_reverse_cummin", [](gpu_array_c values, std::uintptr_t stream) { diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh new file mode 100644 index 00000000..2e4b02c8 --- /dev/null +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -0,0 +1,193 @@ +#pragma once + +#include + +// ============================================================================ +// Shared CSR/CSC -> {compact CSC, dense} extraction kernels. +// +// Header-only templates used by the wilcoxon and rank_genes CUDA modules to +// land a gene-column window on the GPU in a column-usable layout. Two families: +// * compact CSC (csr_scatter_to_csc) -> sparse ranker (nnz only) +// * dense F-order (csr_tile_to_dense, extract) -> dense ranker (all values) +// ============================================================================ + +/** + * Scatter CSR nonzeros into compact CSC for columns [col_start, col_stop). + * write_pos[c - col_start] is the prefix-sum offset for column c; each thread + * atomically claims a unique destination slot. + * + * PRECONDITION: each row's `indices` must be sorted ascending -- the binary + * search for col_start and the `break` at col_stop depend on it; unsorted rows + * would silently drop or misplace nonzeros. Python dispatch calls + * `sort_indices()` before launching this kernel. + * + * `row_offset` is added to the local row index so a row-block rebased to a + * local [0, n_rows) range still records the correct global row id (out-of-core + * row-streaming OVR path). Defaults to 0 for full-matrix callers. + */ +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop, int row_offset = 0) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row_offset + row; + } +} + +// Single-pass CSR-slice + densify: scatter column window [col_lb, col_ub) into +// a dense (n_cells, col_ub-col_lb) F-order double buffer, skipping the CSR -> +// CSC rebuild a `X[:, lb:ub].tocsc()` densify would do. +// +// `out` must be pre-zeroed; the atomicAdd also sums duplicate column indices +// (like scipy's sum_duplicates) -- bit-identical to dense materialization for +// canonical CSR. Output is always double; input dtype is templated. +template +__global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, + const IndexT* __restrict__ indices, + const TData* __restrict__ data, + double* __restrict__ out, int col_lb, + int col_ub, int n_cells) { + const long long row = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (row >= n_cells) { + return; + } + const long long row_start = static_cast(indptr[row]); + const long long row_end = static_cast(indptr[row + 1]); + // Keep column ids in IndexT: narrowing a 64-bit IndexT to int would + // truncate large column ids and misplace writes. + const IndexT lb = static_cast(col_lb); + const IndexT ub = static_cast(col_ub); + for (long long k = row_start; k < row_end; ++k) { + const IndexT col = indices[k]; + if (col >= lb && col < ub) { + atomicAdd(&out[static_cast(col - lb) * n_cells + row], + static_cast(data[k])); + } + } +} + +// CSC column-window [col_lb, col_ub) -> dense F-order (double), single fused +// pass. One block per column; threads stride that column's nonzeros. Writes are +// column-major coalesced and need NO atomicAdd -- canonical CSC has a unique +// (col,row) per nonzero (the wilcoxon dispatch canonicalizes/sums first). This +// is the densify-from-CSC counterpart to csr_tile_to_dense_kernel. +// +// `out` must be pre-zeroed. `indptr` indexes columns; pass either full-matrix +// column pointers (with col_lb/col_ub) or a window rebased to [0, +// col_ub-col_lb). +template +__global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, + const IndexT* __restrict__ indices, + const TData* __restrict__ data, + double* __restrict__ out, int col_lb, + int col_ub, int n_cells) { + const int col = col_lb + static_cast(blockIdx.x); + if (col >= col_ub) return; + const long long col_local = blockIdx.x; + const IndptrT s = indptr[col]; + const IndptrT e = indptr[col + 1]; + for (IndptrT p = s + threadIdx.x; p < e; p += blockDim.x) { + const long long row = static_cast(indices[p]); + out[col_local * n_cells + row] = static_cast(data[p]); + } +} + +// CSR selected rows -> dense F-order. row_ids[tid] = source row; output column +// is (col - col_start), output row is tid. Requires sorted indices (binary +// search + break). Output must be pre-zeroed. +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const IndptrT* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (IndptrT p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} + +// CSR identity-mapped rows -> dense F-order; tolerates UNSORTED indices (full +// row scan, no binary search). One block per row. Output must be pre-zeroed. +template +__global__ void csr_extract_dense_identity_rows_unsorted_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_start && c < col_stop) { + out[(long long)(c - col_start) * n_target + row] = data[p]; + } + } +} + +/** + * Extract rows from CSC into dense F-order via a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column. Output must be pre-zeroed. + */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + IndptrT start = indptr[col]; + IndptrT end = indptr[col + 1]; + + for (IndptrT p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 5b183b1f..18442bb3 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -3,6 +3,7 @@ #include #include +#include #include #include "../nb_types.h" @@ -122,6 +123,210 @@ static void launch_ovr_rank_dense_streaming( sync_streams(streams, "dense OVR streaming rank"); } +// Host-streaming dense OVR: same multi-stream pipeline as the host-CSC path +// (pinned host, round-robin streams, per-batch async H2D overlapping the rank) +// feeding the dense sort+rank above. Both layouts read into an F-order device +// block PER SUB-BATCH (the full array is never transposed): F-order is a +// contiguous memcpy; C-order is a strided cudaMemcpy2DAsync of the sub-batch +// tile then read into F-order. Input dtype is cast to float32 keys; group sums +// (+ nnz) are accumulated in f64 from the native-dtype staging for means/pts. +template +static void launch_ovr_rank_dense_host_streaming( + const T* h_X, bool f_order, const int* group_codes, double* rank_sums, + double* tie_corr, double* group_sums, double* group_nnz, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + const bool compute_stats = group_sums != nullptr; + compute_nnz = compute_nnz && (group_nnz != nullptr); + // F-order float32 input feeds the sort directly (no cast/transpose buffer). + const bool fast_keys = f_order && std::is_same::value; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) { + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + } + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + int sub_items_i32 = + checked_cub_items(sub_items, "Dense host OVR sub-batch"); + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); + + // Clamp the stream count to the device memory budget (like the sparse + // launchers) so a tall/wide host-dense matrix shrinks the pipeline rather + // than OOMing on unbounded per-stream sort scratch. + size_t per_stream_bytes = + sub_items * (sizeof(T) + (fast_keys ? 0 : sizeof(float)) + + sizeof(float) + 2 * sizeof(int)) + + cub_temp_bytes + (size_t)(sub_batch_cols + 1) * sizeof(int) + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + (size_t)sub_batch_cols * sizeof(double) + + (compute_stats ? (size_t)n_groups * sub_batch_cols * sizeof(double) + : 0) + + (compute_nnz ? (size_t)n_groups * sub_batch_cols * sizeof(double) : 0); + n_streams = clamp_streams_by_budget(n_streams, per_stream_bytes, + rmm_available_device_bytes(0.8)); + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + // Best-effort pin of the host array for faster async H2D; on failure (pin + // caps / non-pinnable memory) proceed unpinned rather than raising. + HostRegisterGuard _pin(const_cast(h_X), + (size_t)n_rows * n_cols * sizeof(T), 0, + /*best_effort=*/true); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + struct StreamBuf { + T* d_stg; + float* block_f32; + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].d_stg = pool.alloc(sub_items); + bufs[s].block_f32 = fast_keys ? nullptr : pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + compute_stats + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense host OVR active sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + // H2D the column window on this stream (overlaps the prior batch rank). + if (f_order) { + cudaMemcpyAsync(buf.d_stg, h_X + (size_t)col * n_rows, + (size_t)sb_items * sizeof(T), + cudaMemcpyHostToDevice, stream); + } else { + cudaMemcpy2DAsync(buf.d_stg, (size_t)sb_cols * sizeof(T), h_X + col, + (size_t)n_cols * sizeof(T), + (size_t)sb_cols * sizeof(T), n_rows, + cudaMemcpyHostToDevice, stream); + } + + const float* keys_in; + if (fast_keys) { + keys_in = reinterpret_cast(buf.d_stg); + } else { + // dense_block_to_f32_kernel is grid-stride, so a bounded grid + // covers any sb_items (up to INT_MAX) with no overflow in the + // launch math and enough blocks to saturate the device. + const unsigned int grid = (unsigned int)std::min( + ((size_t)sb_items + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel<<>>( + buf.d_stg, buf.block_f32, n_rows, sb_cols, f_order); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + keys_in = buf.block_f32; + } + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "dense host OVR segmented sort"); + + // gmem rank mode atomicAdds onto sub_rank_sums without self-zeroing, + // and the per-stream buffer is reused round-robin, so zero it first + // (the device-resident launcher does the same). + if (use_gmem) { + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "dense host OVR gmem rank_sums memset"); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVR rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check(cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream), + "dense host OVR tie_corr D2D copy"); + } + + // Group sums (+ nnz) for means/pts, in f64 from the native-dtype + // staging (matches the Aggregate path); fed to + // _fill_basic_stats_from_accumulators like the host-CSC path. + if (compute_stats) { + cudaMemsetAsync(buf.sub_group_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + if (compute_nnz) { + cudaMemsetAsync(buf.sub_group_nnz, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + dense_group_accumulate_kernel + <<>>( + buf.d_stg, group_codes, buf.sub_group_sums, + compute_nnz ? buf.sub_group_nnz : buf.sub_group_sums, + n_rows, sb_cols, n_groups, f_order, compute_nnz); + CUDA_CHECK_LAST_ERROR(dense_group_accumulate_kernel); + scatter_cols_2d(group_sums + col, buf.sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(group_nnz + col, buf.sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense host OVR streaming"); +} + static void launch_ovo_rank_dense_tiered_unsorted_ref( const float* ref_data, const float* grp_data, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, @@ -288,10 +493,46 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( sync_streams(streams, "dense OVO tiered rank"); } +template +static void def_ovr_rank_dense_host_streaming(nb::module_& m) { + m.def( + "ovr_rank_dense_host_streaming", + [](host_array buf, gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, + gpu_array_c group_sums, + gpu_array_c group_nnz, int n_rows, int n_cols, + int n_groups, bool f_order, bool compute_tie_corr, bool compute_nnz, + bool compute_stats, int sub_batch_cols) { + nb_require(buf.shape(0) == (size_t)n_rows * (size_t)n_cols, + "ovr_rank_host: buf length must be n_rows*n_cols"); + nb_require((int)group_codes.shape(0) == n_rows, + "ovr_rank_host: group_codes length must be n_rows"); + nb_require( + (int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovr_rank_host: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_cols, + "ovr_rank_host: tie_corr length must be n_cols"); + launch_ovr_rank_dense_host_streaming( + buf.data(), f_order, group_codes.data(), rank_sums.data(), + tie_corr.data(), compute_stats ? group_sums.data() : nullptr, + compute_nnz ? group_nnz.data() : nullptr, n_rows, n_cols, + n_groups, compute_tie_corr, compute_nnz, sub_batch_cols); + }, + "buf"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, "group_sums"_a, + "group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, + "f_order"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; + def_ovr_rank_dense_host_streaming(m); + def_ovr_rank_dense_host_streaming(m); + m.def( "ovo_rank_dense_tiered_unsorted_ref", [](gpu_array_f ref_data, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 4a659184..399f840c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -14,6 +14,7 @@ #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR #include "../rmm_scratch.h" // rmm_allocate, RmmScratchPool, ScopedCudaBuffer +#include "../sparse_extract/sparse_extract.cuh" // csr_extract_dense* kernels // Host thread count for CPU-side CSR passes: hardware concurrency, capped. static inline int host_worker_count() { @@ -286,14 +287,18 @@ struct HostRegisterGuard { void* ptr = nullptr; HostRegisterGuard() = default; - HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0, + bool best_effort = false) { if (p && bytes > 0) { cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { // Already-registered memory belongs to another owner; use it // without unregistering here. Other failures mean mapped reads - // would be unsafe, so surface them immediately. - if (err == cudaErrorHostMemoryAlreadyRegistered) { + // would be unsafe, so surface them immediately -- unless the + // caller opts into best-effort pinning (the pin is only a + // transfer speedup; plain H2D still works unpinned). + if (err == cudaErrorHostMemoryAlreadyRegistered || + best_effort) { cudaGetLastError(); // clear sticky error flag } else { throw std::runtime_error( @@ -538,56 +543,3 @@ static inline void upload_linear_offsets(int* d_offsets, int n_segments, d_offsets, n_segments, stride); CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); } - -// ============================================================================ -// CSR → dense F-order extraction (templated on data type) -// ============================================================================ - -template -__global__ void csr_extract_dense_kernel(const T* __restrict__ data, - const int* __restrict__ indices, - const IndptrT* __restrict__ indptr, - const int* __restrict__ row_ids, - T* __restrict__ out, int n_target, - int col_start, int col_stop) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n_target) return; - - int row = row_ids[tid]; - IndptrT rs = indptr[row]; - IndptrT re = indptr[row + 1]; - - IndptrT lo = rs, hi = re; - while (lo < hi) { - IndptrT m = lo + ((hi - lo) >> 1); - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; - } - - for (IndptrT p = lo; p < re; ++p) { - int c = indices[p]; - if (c >= col_stop) break; - out[(long long)(c - col_start) * n_target + tid] = data[p]; - } -} - -template -__global__ void csr_extract_dense_identity_rows_unsorted_kernel( - const T* __restrict__ data, const int* __restrict__ indices, - const int* __restrict__ indptr, T* __restrict__ out, int n_target, - int col_start, int col_stop) { - int row = blockIdx.x; - if (row >= n_target) return; - - int rs = indptr[row]; - int re = indptr[row + 1]; - - for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { - int c = indices[p]; - if (c >= col_start && c < col_stop) { - out[(long long)(c - col_start) * n_target + row] = data[p]; - } - } -} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index a9042fed..9cfc81bd 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -2,6 +2,8 @@ #include +#include "../sparse_extract/sparse_extract.cuh" + /** * Build CUB segmented-sort ranges for HUGE-band groups. Ranges point into the * original dense group layout so the presorted rank kernel reads normal @@ -23,32 +25,6 @@ __global__ void build_huge_seg_offsets_kernel( ends[idx] = base + grp_offsets[g + 1]; } -/** - * Extract rows from CSC into dense F-order via a row lookup map. - * row_map[original_row] = output_row_index (or -1 to skip). - * One block per column. Output must be pre-zeroed. - */ -template -__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, - const IndexT* __restrict__ indices, - const IndptrT* __restrict__ indptr, - const int* __restrict__ row_map, - float* __restrict__ out, int n_target, - int col_start) { - int col_local = blockIdx.x; - int col = col_start + col_local; - - IndptrT start = indptr[col]; - IndptrT end = indptr[col + 1]; - - for (IndptrT p = start + threadIdx.x; p < end; p += blockDim.x) { - int out_row = row_map[(int)indices[p]]; - if (out_row >= 0) { - out[(long long)col_local * n_target + out_row] = data[p]; - } - } -} - /** * Sizing knobs for LARGE-band dispatch: when the largest group fits in shared * memory, a fused bitonic-sort + binary-search kernel handles the group per diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 882d2e0e..0973757e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -1,5 +1,7 @@ #pragma once +#include "../sparse_extract/sparse_extract.cuh" + /** Count nonzeros per column from CSR. One thread per row. */ template __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, @@ -16,48 +18,6 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, } } -/** - * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). - * write_pos[c - col_start] is the prefix-sum offset for column c; each thread - * atomically claims a unique destination slot. - * - * PRECONDITION: each row's `indices` must be sorted ascending -- the binary - * search for col_start and the `break` at col_stop depend on it; unsorted rows - * would silently drop or misplace nonzeros. Python dispatch calls - * `sort_indices()` before launching this kernel. - * - * `row_offset` is added to the local row index so a row-block rebased to a - * local [0, n_rows) range still records the correct global row id (out-of-core - * row-streaming OVR path). Defaults to 0 for full-matrix callers. - */ -template -__global__ void csr_scatter_to_csc_kernel( - const InT* __restrict__ data, const IndexT* __restrict__ indices, - const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, - InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, - int col_start, int col_stop, int row_offset = 0) { - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= n_rows) return; - IndptrT rs = indptr[row]; - IndptrT re = indptr[row + 1]; - // Binary search for col_start (overflow-safe midpoint) - IndptrT lo = rs, hi = re; - while (lo < hi) { - IndptrT m = lo + ((hi - lo) >> 1); - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; - } - for (IndptrT p = lo; p < re; ++p) { - int c = (int)indices[p]; - if (c >= col_stop) break; - int dest = atomicAdd(&write_pos[c - col_start], 1); - csc_vals[dest] = data[p]; - csc_row_idx[dest] = row_offset + row; - } -} - // CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). // // Decide smem-vs-gmem for the DENSE OVR rank kernel. Per-block accumulator is @@ -113,3 +73,58 @@ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, out[i] = i; } } + +/** + * Read one transferred dense column-batch (native dtype `T`) into float32 in + * F-order (column-major), the layout the segmented sort expects. Operates on a + * single sub-batch (n_rows x sb_cols) only -- the full array is never + * reordered/transposed. + * f_order=true : staging is already F-order -> identity cast. + * f_order=false: staging is C-order (n_rows x sb_cols, row-major); each + * element is read into its F-order slot while casting. + * Grid-stride over n_rows*sb_cols elements. + */ +template +__global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, + float* __restrict__ out, int n_rows, + int sb_cols, bool f_order) { + const long long total = (long long)n_rows * sb_cols; + const long long stride = (long long)gridDim.x * blockDim.x; + for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x; + idx < total; idx += stride) { + if (f_order) { + out[idx] = (float)stg[idx]; + } else { + int col = (int)(idx / n_rows); + int row = (int)(idx % n_rows); + out[idx] = (float)stg[(long long)row * sb_cols + col]; + } + } +} + +/** + * Accumulate per-(group, column) sums (+ optional nnz) from a transferred dense + * column-batch, reading the NATIVE dtype staging in f64 so means match the + * Aggregate path (the f32 cast is only for ranking). One block per column. + * `group_sums`/`group_nnz` are this batch's (n_groups x sb_cols) buffers and + * must be pre-zeroed. Mirrors the sparse cast+accumulate the CSC host path + * runs. + */ +template +__global__ void dense_group_accumulate_kernel( + const T* __restrict__ stg, const int* __restrict__ group_codes, + double* __restrict__ group_sums, double* __restrict__ group_nnz, int n_rows, + int sb_cols, int n_groups, bool f_order, bool compute_nnz) { + int col = blockIdx.x; + if (col >= sb_cols) return; + for (int row = threadIdx.x; row < n_rows; row += blockDim.x) { + int g = group_codes[row]; + if (g < 0 || g >= n_groups) continue; + double v = f_order ? (double)stg[(long long)col * n_rows + row] + : (double)stg[(long long)row * sb_cols + col]; + atomicAdd(&group_sums[(long long)g * sb_cols + col], v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(long long)g * sb_cols + col], 1.0); + } + } +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index c2c7f1b3..71a43f9f 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -67,7 +67,6 @@ def rank_genes_groups( return_u_values: bool = False, layer: str | None = None, chunk_size: int | None = None, - pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, skip_empty_groups: bool = False, @@ -160,8 +159,6 @@ def rank_genes_groups( `'wilcoxon_binned'`. Default is 512 for `'wilcoxon'`. For `'wilcoxon_binned'` the default is sized dynamically based on ``n_groups`` and ``n_bins`` to keep histogram memory stable. - pre_load - Pre-load the data into GPU memory. Used only for `'wilcoxon'`. n_bins Number of histogram bins for `'wilcoxon_binned'`. Higher values give a better approximation at slightly increased cost. Default is 1000 @@ -268,7 +265,6 @@ def rank_genes_groups( use_raw=use_raw, layer=layer, comp_pts=pts, - pre_load=pre_load, skip_empty_groups=skip_empty_groups, ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 4991b9ce..89a5dfde 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -46,7 +46,6 @@ def __init__( use_raw: bool | None = None, layer: str | None = None, comp_pts: bool = False, - pre_load: bool = False, skip_empty_groups: bool = False, ) -> None: if groups == "all" or groups is None: @@ -101,8 +100,6 @@ def __init__( self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] - self.pre_load = pre_load - self.ireference = None if reference != "rest": self.ireference = int(np.where(self.groups_order == str(reference))[0][0]) @@ -391,7 +388,7 @@ def compute_statistics( # rank each stored nnz once, so they must see scanpy's summed view. self.X = _canonicalize_sparse(self.X) self._sparse_negative_fallback = _sparse_has_negative(self.X) - if self.pre_load or method in { + if method in { "t-test", "t-test_overestim_var", "wilcoxon_binned", diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 796354d0..913abe89 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -7,8 +7,6 @@ import numpy as np import scipy.sparse as sp -from rapids_singlecell.preprocessing._utils import _sparse_to_dense - if TYPE_CHECKING: import pandas as pd from numpy.typing import NDArray @@ -146,20 +144,31 @@ def _choose_chunk_size(requested: int | None) -> int: def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: """ - Extract columns from a CSC matrix via direct indptr pointer slicing. + Densify a CSC column window [start, stop) into an F-order float64 block via + the fused ``csc_tile_to_dense`` kernel (column-major, coalesced, no atomics). - Works for both scipy and CuPy CSC matrices. Much faster than - ``X[:, start:stop]`` which rebuilds index arrays internally. + Slices the window by indptr pointers so only that window's nonzeros are + touched (and, for host CSC, transferred). Works for scipy and CuPy CSC. """ + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + s_ptr = int(X_csc.indptr[start]) e_ptr = int(X_csc.indptr[stop]) - chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) - chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) - chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) - csc_chunk = cpsp.csc_matrix( - (chunk_data, chunk_indices, chunk_indptr), shape=(n_rows, stop - start) - ) - return _sparse_to_dense(csc_chunk, order="F").astype(cp.float64) + out = cp.zeros((n_rows, stop - start), dtype=cp.float64, order="F") + if e_ptr > s_ptr: + chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) + chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) + chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) + _rs.csc_tile_to_dense( + chunk_indptr, + chunk_indices, + chunk_data, + out, + col_lb=0, + col_ub=stop - start, + stream=cp.cuda.get_current_stream().ptr, + ) + return out def _csr_tile_to_dense_block(X, start: int, stop: int) -> cp.ndarray: @@ -203,12 +212,48 @@ def _get_column_block(X, start: int, stop: int) -> cp.ndarray: return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case sp.spmatrix() | sp.sparray(): chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) - return _sparse_to_dense(chunk, order="F").astype(cp.float64) + return _csc_columns_to_gpu(chunk, 0, chunk.shape[1], X.shape[0]) case cpsp.csc_matrix(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case cpsp.spmatrix(): - return _sparse_to_dense(X[:, start:stop], order="F").astype(cp.float64) + chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) + return _csc_columns_to_gpu(chunk, 0, chunk.shape[1], X.shape[0]) case np.ndarray() | cp.ndarray(): return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") case _: raise ValueError(f"Unsupported matrix type: {type(X)}") + + +def _ovr_dense_block_f32(X, start: int, stop: int) -> cp.ndarray: + """OVR (vs-rest): ALL cells x gene-window, F-order float32. + + For sparse X (the negative-values dense fallback) the window is densified on + the fly via the shared CSR/CSC densify path (`_get_column_block`), so no + full-matrix dense materialization happens. + """ + if isinstance(X, np.ndarray | cp.ndarray): + return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + if sp.issparse(X) or cpsp.issparse(X): + block = _get_column_block(X, start, stop) # float64 F-order chunk + return cp.asfortranarray(block.astype(cp.float32, copy=False)) + raise TypeError(f"Expected dense matrix, got {type(X)}") + + +def _ovo_dense_block(X, row_ids: np.ndarray, start: int, stop: int) -> cp.ndarray: + """OVO (with-reference): a ROW SUBSET (`row_ids`) x gene-window, F-order. + + OVO ranks the reference group against each other group, so it materializes + only the selected rows -- unlike `_ovr_dense_block_f32`, which takes all + cells. + """ + if isinstance(X, np.ndarray): + return cp.asarray(X[row_ids, start:stop], order="F") + if isinstance(X, cp.ndarray): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows, start:stop]) + if isinstance(X, sp.spmatrix | sp.sparray): + return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") + if cpsp.issparse(X): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows][:, start:stop].toarray()) + raise TypeError(f"Unsupported matrix type: {type(X)}") diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 8189a0c2..892d48db 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -12,7 +12,14 @@ from rapids_singlecell._cuda import _wilcoxon_cuda as _wc from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import EPS, MIN_GROUP_SIZE_WARNING, _choose_chunk_size, _get_column_block +from ._utils import ( + EPS, + MIN_GROUP_SIZE_WARNING, + _choose_chunk_size, + _get_column_block, + _ovo_dense_block, + _ovr_dense_block_f32, +) if TYPE_CHECKING: from numpy.typing import NDArray @@ -32,7 +39,13 @@ def _maybe_preload_host_dense(rg: _RankGenes) -> None: - """Preload moderate host-dense matrices to avoid repeated chunk transfers.""" + """Preload a moderate host-dense matrix to the GPU (OVO only). + + OVO caches the sorted reference group on the device and ranks every other + group against it, so staging the matrix on the GPU once is intended -- it + avoids re-streaming the shared reference per group. The OVR path instead + streams column chunks from host (see ``ovr_rank_dense_host_streaming``). + """ X = rg.X if not isinstance(X, np.ndarray) or X.size == 0: return @@ -72,38 +85,6 @@ def _maybe_preload_host_dense(rg: _RankGenes) -> None: rg.X = X_gpu -def _get_dense_column_block_f32(X, start: int, stop: int) -> cp.ndarray: - """Extract a dense column block as F-order float32 CuPy memory. - - For sparse X (the negative-values dense fallback) the column window is - densified on the fly via the shared CSR/CSC densify path, so no full-matrix - dense materialization happens. - """ - if isinstance(X, np.ndarray | cp.ndarray): - return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") - if sp.issparse(X) or cpsp.issparse(X): - block = _get_column_block(X, start, stop) # float64 F-order chunk - return cp.asfortranarray(block.astype(cp.float32, copy=False)) - raise TypeError(f"Expected dense matrix, got {type(X)}") - - -def _extract_dense_rows_cols( - X, row_ids: np.ndarray, start: int, stop: int -) -> cp.ndarray: - """Extract a bounded row/column block as F-order CuPy dense memory.""" - if isinstance(X, np.ndarray): - return cp.asarray(X[row_ids, start:stop], order="F") - if isinstance(X, cp.ndarray): - rows = cp.asarray(row_ids, dtype=cp.int32) - return cp.asfortranarray(X[rows, start:stop]) - if isinstance(X, sp.spmatrix | sp.sparray): - return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") - if cpsp.issparse(X): - rows = cp.asarray(row_ids, dtype=cp.int32) - return cp.asfortranarray(X[rows][:, start:stop].toarray()) - raise TypeError(f"Unsupported matrix type: {type(X)}") - - def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: if requested is not None: return _choose_chunk_size(requested) @@ -363,7 +344,7 @@ def _device_sparse_arrays_f32(X): """Cast device-sparse arrays for the Wilcoxon kernels. Wilcoxon ranking sorts float32 keys on every path -- the sparse fast paths - AND the dense fallback (``_get_dense_column_block_f32``); the CUB segmented + AND the dense fallback (``_ovr_dense_block_f32``); the CUB segmented sort is float-keyed throughout. Casting ``X.data`` to float32 here therefore does not diverge from any float64 ranking path, because there is none. This only loses precision when preprocessing ran in float64; float32-preprocessed @@ -445,6 +426,9 @@ def _column_totals_for_host_matrix( "Wilcoxon sparse input must be CSR or CSC; refusing hidden " f"full-matrix conversion from {X.format!r}." ) + elif isinstance(X, np.ndarray): + sums = X.sum(axis=0, dtype=np.float64) + nnz = (X != 0).sum(axis=0).astype(np.float64) if compute_nnz else None else: raise TypeError(f"Unsupported host matrix type: {type(X)}") @@ -478,7 +462,10 @@ def wilcoxon( return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - _maybe_preload_host_dense(rg) + # OVO caches the reference on the GPU and ranks each group against it, so + # preloading host-dense input is intended. OVR streams chunks from host. + if rg.ireference is not None: + _maybe_preload_host_dense(rg) # Aggregate if on GPU, else defer to chunks. rg._basic_stats() X = rg.X @@ -702,6 +689,76 @@ def _wilcoxon_vs_rest( group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev + # Host dense: stream column chunks from host with the CSC-style pipeline + # (per-batch H2D overlapping the rank), instead of moving the whole array to + # the GPU. Ranking + group sums/nnz come back in one streamed pass. + if isinstance(X, np.ndarray): + # Only float32/float64 host bindings exist; cast int/bool/uint/float16 + # to float32 (mirrors the sparse paths) rather than raising a TypeError. + if X.dtype.kind != "f" or X.dtype.itemsize < 4: + X = X.astype(np.float32) + if X.flags.f_contiguous: + buf, f_order = X.ravel(order="K"), True + elif X.flags.c_contiguous: + buf, f_order = X.ravel(order="K"), False + else: + buf, f_order = np.ascontiguousarray(X).ravel(order="K"), False + compute_nnz = rg.comp_pts + compute_stats = rg._compute_stats_in_chunks + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = ( + cp.empty(n_total_genes, dtype=cp.float64) + if tie_correct + else cp.ones(n_total_genes, dtype=cp.float64) + ) + stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + (n_groups, n_total_genes) if (compute_stats and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + _wc.ovr_rank_dense_host_streaming( + buf, + group_codes_gpu, + rank_sums, + tie_corr, + group_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + f_order=f_order, + compute_tie_corr=tie_correct, + compute_nnz=compute_stats and compute_nnz, + compute_stats=compute_stats, + sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + if compute_stats: + total_sums, total_nnz = _host_ovr_totals_if_needed( + X, rg.group_codes, n_groups, compute_nnz=compute_nnz + ) + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes, + n_cells=n_cells, + total_sums=total_sums, + total_nnz=total_nnz, + ) + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + scores_host = scores.get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) all_scores: dict[int, list] = {i: [] for i in range(n_groups)} @@ -722,7 +779,7 @@ def _wilcoxon_vs_rest( ) block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) else: - block_f32 = _get_dense_column_block_f32(X, start, stop) + block_f32 = _ovr_dense_block_f32(X, start, stop) n_cols = stop - start rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) @@ -1059,8 +1116,8 @@ def _wilcoxon_with_reference( stop = min(start + chunk_width, n_total_genes) n_cols = stop - start - ref_block = _extract_dense_rows_cols(X, ref_row_ids, start, stop) - grp_block = _extract_dense_rows_cols(X, all_grp_row_ids, start, stop) + ref_block = _ovo_dense_block(X, ref_row_ids, start, stop) + grp_block = _ovo_dense_block(X, all_grp_row_ids, start, stop) _fill_ovo_chunk_stats( rg, diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 9f2f7a66..baa2c06b 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -670,8 +670,7 @@ def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): "cupy_csc", ], ) -@pytest.mark.parametrize("pre_load", [False, True]) -def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) _make_nonnegative(adata_gpu) @@ -687,7 +686,7 @@ def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): "tie_correct": True, "n_genes": 5, } - rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=pre_load) + rsc.tl.rank_genes_groups(adata_gpu, **kw) sc.tl.rank_genes_groups(adata_cpu, **kw) gpu_result = adata_gpu.uns["rank_genes_groups"] @@ -834,6 +833,127 @@ def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): ) +# Regression guard for the host-dense OVR streaming launcher: in gmem rank mode +# the rank kernel atomicAdds onto its per-stream rank-sum buffer without +# self-zeroing, and those buffers are reused round-robin -- so the launcher must +# zero them per sub-batch. This only bites past the DENSE gmem flip +# (n_groups > ~6112) AND with enough genes (> N_STREAMS*sub_batch_cols = 256) +# that the per-stream buffers actually wrap; the formats-agree gmem test above +# uses n_genes=6 (one batch, fresh buffer) and cannot see it. Compares the host +# (numpy) dense path against the device (cupy) dense + sparse paths, which zero +# the gmem buffer correctly. +@pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn +def test_wilcoxon_ovr_dense_gmem_host_streaming_buffer_reuse(): + adata = _make_sized_groups_adata([2] * 6200, n_genes=400, seed=7) + ref = None + for fmt in ("cupy_dense", "numpy_dense", "cupy_csr"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + if ref is None: + ref = cur + continue + for field in ("scores", "pvals"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +# Host-dense OVR has only float32/float64 nanobind overloads; integer/bool/uint/ +# float16 numpy must be cast to float32 (mirrors the sparse path) rather than +# raising a TypeError. +@pytest.mark.parametrize( + "data_dtype", [np.int32, np.int64, np.uint16, np.float16, bool] +) +def test_wilcoxon_dense_nonfloat_data_matches_float32(data_dtype): + rng = np.random.default_rng(5) + n_obs, n_genes = 120, 8 + counts = rng.integers(0, 5, size=(n_obs, n_genes)) + if data_dtype is bool: + counts = counts > 2 + typed = np.ascontiguousarray(counts.astype(data_dtype)) + f32 = np.ascontiguousarray(counts.astype(np.float32)) + labels = np.array([f"{i % 3}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + + def run(arr): + adata = sc.AnnData(X=arr, obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + n_genes=n_genes, + ) + return adata.uns["rank_genes_groups"] + + r_typed = run(typed) + r_f32 = run(f32) + for grp in r_typed["scores"].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_typed["scores"][grp], dtype=float), + np.asarray(r_f32["scores"][grp], dtype=float), + ) + + +# F-contiguous host-dense numpy hits the f_order=True branch of the host- +# streaming launcher: float32 -> the reinterpret-cast fast path (no cast kernel), +# float64 -> dense_block_to_f32_kernel's identity branch. Every numpy_dense +# fixture elsewhere is C-order, so this is the only coverage of that branch. +# AnnData preserves F-order, so an F-contiguous X reaches the path; result must +# match the C-order run on identical data. +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_wilcoxon_ovr_fortran_order_host_dense_matches_c_order(dtype): + rng = np.random.default_rng(11) + X = np.abs(rng.standard_normal((300, 40))).astype(dtype) + X[X < 0.3] = 0.0 + labels = rng.integers(0, 5, 300) + obs = pd.DataFrame({"group": pd.Categorical([f"g{c}" for c in labels])}) + var = pd.DataFrame(index=[f"g{j}" for j in range(40)]) + + def run(arr): + adata = sc.AnnData(X=arr, obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + ) + return adata.uns["rank_genes_groups"] + + xf = np.asfortranarray(X) + assert xf.flags.f_contiguous + r_f = run(xf) + r_c = run(np.ascontiguousarray(X)) + for field in ("scores", "pvals", "logfoldchanges"): + for grp in r_f[field].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_f[field][grp], dtype=float), + np.asarray(r_c[field][grp], dtype=float), + ) + + # Regression guard for a shared-memory OOB write in the host sparse OVR # cast-and-accumulate kernel: it placed the per-group nnz accumulator at a fixed # 2*n_groups smem offset, but cast_accumulate_smem_config packs only the enabled @@ -931,9 +1051,8 @@ def test_wilcoxon_ovr_many_groups_gmem_pts_formats_agree(): ], ) @pytest.mark.parametrize("tie_correct", [False, True]) -@pytest.mark.parametrize("pre_load", [False, True]) def test_rank_genes_groups_wilcoxon_subset_matches_scanpy( - groups, reference, tie_correct, pre_load + groups, reference, tie_correct ): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=5, n_observations=200) @@ -948,7 +1067,6 @@ def test_rank_genes_groups_wilcoxon_subset_matches_scanpy( reference=reference, use_raw=False, tie_correct=tie_correct, - pre_load=pre_load, ) sc.tl.rank_genes_groups( adata_cpu, @@ -1066,8 +1184,7 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): @pytest.mark.parametrize("reference", ["rest", "1"]) -@pytest.mark.parametrize("pre_load", [True, False]) -def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): +def test_rank_genes_groups_wilcoxon_pts(reference): """Test that pts (fraction of cells expressing) is computed correctly.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) @@ -1084,7 +1201,6 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): pts=True, tie_correct=False, reference=reference, - pre_load=pre_load, ) sc.tl.rank_genes_groups( adata_cpu, From 58bbd4ad528cfc1d30d6108703f61e0124bd4817 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 24 Jun 2026 15:17:37 +0200 Subject: [PATCH 28/36] remove small and tiny and speed up larger paths --- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 461 ++++-------------- .../_cuda/wilcoxon/wilcoxon.cu | 6 +- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 15 +- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 20 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 15 +- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 121 ++--- tests/test_rank_genes_groups_wilcoxon.py | 58 +-- 7 files changed, 169 insertions(+), 527 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 65916525..8bbf4bd2 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -96,64 +96,86 @@ __device__ __forceinline__ OvoRank ovo_mid_rank(const float* ref, int n_ref, } // ============================================================================ -// Parallel tie correction — all threads collaborate. +// Amortized tie correction for the LARGE/HUGE bands (group is SORTED). // -// Accumulates t^3 - t per unique value of the combined (ref, grp) arrays via -// two passes: ref uniques (counted in both), then grp uniques absent from ref. -// Incremental binary search bounds exploit per-thread-stride monotonicity. -// Caller must __syncthreads() first. warp_buf (32 doubles) reused for -// reduction. +// Adds only the group-only / ref-overlap delta on top of the precomputed +// reference base ref_tie_sums[col] (= ref_tie_sum_kernel), exactly like the +// MEDIUM band. Iterates the sorted group's UNIQUE values only -- one binary +// search into the ref per unique value -- so the reference is NOT rescanned per +// group (the cost a naive full-combined-rescan would pay). This makes tie +// correction O(n_grp_unique * log n_ref) instead of O(n_ref) per group, which +// dominates whenever the reference is large (e.g. a perturbation control or a +// big cluster). Bit-identical: same per-value (t^3 - t) terms, just +// reassociated against the shared ref base. // ============================================================================ - -__device__ __forceinline__ void compute_tie_correction_parallel( +__device__ __forceinline__ void compute_tie_delta_sorted_grp( const float* ref_col, int n_ref, const float* grp_col, int n_grp, - double* warp_buf, double* out) { - double local_tie = 0.0; - - // Pass 1: unique values in ref_col, counted in both arrays - int grp_lb = 0, grp_ub = 0; - for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { - if (i == 0 || ref_col[i] != ref_col[i - 1]) { - float v = ref_col[i]; - - int cnt_ref = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; - - int lb = sorted_lower_bound(grp_col, grp_lb, n_grp, v); - grp_lb = lb; - grp_ub = sorted_upper_bound(grp_col, grp_ub > lb ? grp_ub : lb, - n_grp, v); - int cnt_grp = grp_ub - lb; - - int cnt = cnt_ref + cnt_grp; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } - } - } - - // Pass 2: unique values in grp_col that are absent from ref_col - int ref_lb = 0; + double ref_base, double* warp_buf, double* out) { + double local = 0.0; for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + // run-start of a unique value in the sorted group if (i == 0 || grp_col[i] != grp_col[i - 1]) { float v = grp_col[i]; - - int lo = sorted_lower_bound(ref_col, ref_lb, n_ref, v); - ref_lb = lo; - - if (lo >= n_ref || ref_col[lo] != v) { - // Value absent from ref — count in grp only - int cnt = sorted_upper_bound(grp_col, i + 1, n_grp, v) - i; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } + int gub = sorted_upper_bound(grp_col, i + 1, n_grp, v); + double cg = (double)(gub - i); + int rlo = sorted_lower_bound(ref_col, 0, n_ref, v); + int rub = sorted_upper_bound(ref_col, rlo, n_ref, v); + double cr = (double)(rub - rlo); + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local += combined * combined * combined - combined - ref_tie - + group_tie; } } } + double tie = wilcoxon_block_sum(local, warp_buf); + if (threadIdx.x == 0) + *out = finalize_tie_corr(n_ref + n_grp, ref_base + tie); +} - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) *out = finalize_tie_corr(n_ref + n_grp, tie_sum); +// ============================================================================ +// No-tie fast path (tie_correct=False, the default). Ranks each group value +// against the sorted REFERENCE only, via the Mann-Whitney U identity: +// R_g = n_grp(n_grp+1)/2 + Σ_{g values}(#ref_below + 0.5·#ref_equal) +// The group-internal ranks collapse to the closed form, so the group block +// needs NO sort — each value does a full binary search into sorted ref. This +// eliminates the group segmented sort, ~half of dense-OVO time (see profiling). +// rank_sums are exact half-integers, so this matches the tiered path +// bit-for-bit. Grid: (n_cols, n_groups), Block: (tpb,). grp_dense is UNSORTED. +// ============================================================================ +__global__ void ovo_rank_dense_vs_ref_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + int n_ref, int n_all_grp, int n_cols, int n_groups) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int n_grp = grp_offsets[grp + 1] - g_start; + if (n_grp == 0) { + if (threadIdx.x == 0) rank_sums[(size_t)grp * n_cols + col] = 0.0; + return; + } + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + + double local_sum = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int n_lt = sorted_lower_bound(ref_col, 0, n_ref, v); + int n_eq = sorted_upper_bound(ref_col, n_lt, n_ref, v) - n_lt; + local_sum += (double)n_lt + 0.5 * (double)n_eq; + } + __shared__ double warp_buf[32]; + double total = wilcoxon_block_sum(local_sum, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = + total + (double)n_grp * ((double)n_grp + 1.0) / 2.0; + } } // ============================================================================ @@ -165,7 +187,8 @@ __device__ __forceinline__ void compute_tie_correction_parallel( __global__ void ovo_rank_huge_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, - const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { int col = blockIdx.x; @@ -208,8 +231,11 @@ __global__ void ovo_rank_huge_kernel( if (!compute_tie_corr) return; __syncthreads(); - compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, - &tie_corr[grp * n_cols + col]); + // grp_col is sorted (CUB segmented sort upstream): amortize the ref tie + // contribution via the precomputed base instead of rescanning the ref. + compute_tie_delta_sorted_grp(ref_col, n_ref, grp_col, n_grp, + ref_tie_sums[col], warp_buf, + &tie_corr[grp * n_cols + col]); } // ============================================================================ @@ -220,12 +246,13 @@ __global__ void ovo_rank_huge_kernel( // ============================================================================ __global__ void ovo_rank_large_kernel( - const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted - const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) - // unsorted - const int* __restrict__ grp_offsets, // (n_groups + 1,) - double* __restrict__ rank_sums, // (n_groups, n_cols) row-major - double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + const double* __restrict__ ref_tie_sums, // (n_cols,) ref tie base + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int large_padded, int skip_n_grp_le /*= 0*/) { int col = blockIdx.x; @@ -284,9 +311,11 @@ __global__ void ovo_rank_large_kernel( if (!compute_tie_corr) return; __syncthreads(); - // grp_smem is sorted here - compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, - &tie_corr[grp * n_cols + col]); + // grp_smem is sorted here: amortize the ref tie contribution via the + // precomputed base instead of rescanning the ref per group. + compute_tie_delta_sorted_grp(ref_col, n_ref, grp_smem, n_grp, + ref_tie_sums[col], warp_buf, + &tie_corr[grp * n_cols + col]); } // ============================================================================ @@ -319,70 +348,6 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, if (threadIdx.x == 0) ref_tie_sums[col] = total; } -__global__ void ovo_rank_small_kernel( - const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, - const int* __restrict__ grp_offsets, - const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int skip_n_grp_le) { - int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - if (n_grp <= skip_n_grp_le || n_grp > OVO_SMALL_MAX) return; - - __shared__ float grp_smem[OVO_SMALL_MAX]; - __shared__ double warp_buf[WARP_REDUCE_BUF]; - - const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; - const float POS_INF = __int_as_float(0x7f800000); - if (threadIdx.x < OVO_SMALL_MAX) { - grp_smem[threadIdx.x] = - (threadIdx.x < n_grp) ? grp_col[threadIdx.x] : POS_INF; - } - __syncthreads(); - - bitonic_sort_smem(grp_smem, OVO_SMALL_MAX); - - const float* ref_col = ref_sorted + (long long)col * n_ref; - double local_sum = 0.0; - double local_tie_delta = 0.0; - - if (threadIdx.x < n_grp) { - float v = grp_smem[threadIdx.x]; - int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0; - OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_smem, n_grp, v, ref_lb, - ref_ub, grp_lb, grp_ub); - local_sum += r.mid_rank; - - if (compute_tie_corr && - (threadIdx.x == 0 || v != grp_smem[threadIdx.x - 1])) { - double combined = (double)(r.n_eq_ref + r.n_eq_grp); - if (combined > 1.0) { - local_tie_delta += combined * combined * combined - combined; - } - if (r.n_eq_ref > 1) { - double cr = (double)r.n_eq_ref; - local_tie_delta -= cr * cr * cr - cr; - } - } - } - - double total = wilcoxon_block_sum(local_sum, warp_buf); - if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; - - if (!compute_tie_corr) return; - __syncthreads(); - - double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); - if (threadIdx.x == 0) - tie_corr[grp * n_cols + col] = - finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); -} - // ============================================================================ // MEDIUM-band fused kernel: no-sort direct rank for medium groups. // @@ -468,247 +433,7 @@ __global__ void ovo_rank_medium_kernel( finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); } -// ============================================================================ -// Warp-scoped tie correction for the WARP band. -// -// Sorted values live in a 32-lane register (one per lane, with unused lanes -// carrying +INF). Walks unique values via lane-step differentials and -// counts ties across the sorted ref column via binary search. All the -// sync is __syncwarp — no smem, no __syncthreads. -// ============================================================================ - -__device__ __forceinline__ double warp_tie_sum(const float* ref_col, int n_ref, - float v_lane, int n_grp, - unsigned int active_mask) { - int lane = threadIdx.x & 31; - double local_tie = 0.0; - - // Pass 1: per unique ref value, count occurrences in ref and in the - // sorted group (held in register v_lane across 32 lanes). - for (int base = 0; base < n_ref; base += 32) { - int i = base + lane; - bool in_ref_lane = (i < n_ref); - float v = in_ref_lane ? ref_col[i] : 0.0f; - bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); - int cnt_ref = 0; - if (is_first) { - cnt_ref = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; - } - - // Count in grp: look up how many lanes hold v_lane == v. All lanes - // execute the shuffle loop; only lanes owning a unique ref value use - // the result. - int cnt_grp = 0; -#pragma unroll - for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { - float vi = __shfl_sync(0xffffffff, v_lane, lane_i); - if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; - } - - if (is_first) { - int cnt = cnt_ref + cnt_grp; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } - } - } - - // Pass 2: unique values in grp that are absent from ref. - // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, - // binary-search ref for v. If not present, count consecutive matching - // lanes (tie block). - if (lane < n_grp) { - float v = v_lane; - float prev_lane_v = - __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); - float v_prev = - (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF - bool first_in_grp = (lane == 0) || (v != v_prev); - bool in_ref = false; - if (first_in_grp) { - int lo = sorted_lower_bound(ref_col, 0, n_ref, v); - in_ref = (lo < n_ref) && (ref_col[lo] == v); - } - - // Count how many lanes ≥ this lane hold the same v. Keep the shuffle - // uniform across active lanes even though only unique, ref-absent - // group values consume the count. - int cnt = 0; -#pragma unroll - for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { - int src_lane = (lane_i < n_grp) ? lane_i : 0; - float vi = __shfl_sync(active_mask, v_lane, src_lane); - if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && - vi == v) { - ++cnt; - } - } - if (first_in_grp && !in_ref && cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } - } - - local_tie = warp_reduce_sum(local_tie); - return local_tie; // meaningful on lane 0. -} - -__device__ __forceinline__ double warp_tie_delta(const float* ref_col, - int n_ref, float v_lane, - int n_grp, - unsigned int active_mask) { - int lane = threadIdx.x & 31; - double local_delta = 0.0; - - if (lane < n_grp) { - float v = v_lane; - float prev_lane_v = - __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); - float v_prev = - (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF - bool first_in_grp = (lane == 0) || (v != v_prev); - - int cnt_grp = 0; -#pragma unroll - for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { - int src_lane = (lane_i < n_grp) ? lane_i : 0; - float vi = __shfl_sync(active_mask, v_lane, src_lane); - if (lane_i < n_grp && vi == v) ++cnt_grp; - } - - if (first_in_grp) { - int ref_lb = sorted_lower_bound(ref_col, 0, n_ref, v); - int cnt_ref = - sorted_upper_bound(ref_col, ref_lb, n_ref, v) - ref_lb; - - double combined = (double)(cnt_ref + cnt_grp); - if (combined > 1.0) { - local_delta += combined * combined * combined - combined; - } - if (cnt_ref > 1) { - double cr = (double)cnt_ref; - local_delta -= cr * cr * cr - cr; - } - } - } - - local_delta = warp_reduce_sum(local_delta); - return local_delta; // meaningful on lane 0. -} - -// ============================================================================ -// WARP-band kernel: warp-per-(col, group) pair, 8 warps packed per block. -// -// Each warp independently loads ≤32 group values into registers (one per -// lane), bitonic-sorts via __shfl_xor_sync, binary-searches into sorted ref, -// and warp-reduces to lane 0. 8 pairs/block cuts block count 8× vs the -// block-per-pair LARGE band; no smem/__syncthreads lets warps run at full -// throughput independently. -// -// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. -// ============================================================================ - -__global__ void ovo_rank_warp_kernel(const float* __restrict__ ref_sorted, - const float* __restrict__ grp_dense, - const int* __restrict__ grp_offsets, - const double* __restrict__ ref_tie_sums, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_ref, - int n_all_grp, int n_cols, int n_groups, - bool compute_tie_corr) { - constexpr int WARPS_PER_BLOCK = 8; - int warp_id = threadIdx.x >> 5; - int lane = threadIdx.x & 31; - - int col = blockIdx.x; - int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - // This kernel only handles groups that fit in a single warp (one value - // per lane). Larger groups are delegated to LARGE/HUGE in a co-launched - // kernel; since each group owns its own row in rank_sums/tie_corr, the - // two kernels interlace into the output without conflict. - if (n_grp > OVO_WARP_MAX) return; - - if (n_grp == 0) { - if (lane == 0) { - rank_sums[grp * n_cols + col] = 0.0; - if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; - } - return; - } - - // One value per lane, pad with +INF so sort pushes them to the end. - const float POS_INF = __int_as_float(0x7f800000); - const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; - float x = (lane < n_grp) ? grp_col[lane] : POS_INF; - unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); - - // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. - for (int k = 1; k <= 16; k <<= 1) { - for (int j = k; j > 0; j >>= 1) { - float y = __shfl_xor_sync(0xffffffff, x, j); - bool asc = (((lane & (k << 1)) == 0)); - bool take_min = (((lane & j) == 0) == asc); - x = take_min ? fminf(x, y) : fmaxf(x, y); - } - } - - // After sort, x[lane] holds the lane-th smallest group value (lanes - // ≥ n_grp hold +INF). Binary-search each value into the sorted ref. - const float* ref_col = ref_sorted + (long long)col * n_ref; - double local_sum = 0.0; - - if (lane < n_grp) { - float v = x; - int n_lt_ref = sorted_lower_bound(ref_col, 0, n_ref, v); - int n_eq_ref = - sorted_upper_bound(ref_col, n_lt_ref, n_ref, v) - n_lt_ref; - - // In-group counts: in the sorted warp-register x, count lanes < this - // one that hold strictly less, and lanes with equal value. - int n_lt_grp = 0; - int n_eq_grp_offset = 0; // tied lanes strictly before this one - int n_eq_grp_after = 1; // count self -#pragma unroll - for (int lane_i = 0; lane_i < OVO_WARP_MAX; ++lane_i) { - if (lane_i >= n_grp) continue; - float vi = __shfl_sync(active_mask, v, lane_i); - if (lane_i < lane) { - if (vi < v) - ++n_lt_grp; - else if (vi == v) - ++n_eq_grp_offset; - } else if (lane_i > lane) { - if (vi == v) ++n_eq_grp_after; - } - } - int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; - // Per-lane mid-rank; each tie lane gets the same value (matches LARGE - // band). - local_sum = (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; - } - - // Warp reduce. - local_sum = warp_reduce_sum(local_sum); - if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; - - if (!compute_tie_corr) return; - - double tie_sum; - if (ref_tie_sums != nullptr) { - tie_sum = ref_tie_sums[col] + - warp_tie_delta(ref_col, n_ref, x, n_grp, active_mask); - } else { - tie_sum = warp_tie_sum(ref_col, n_ref, x, n_grp, active_mask); - } - if (lane == 0) - tie_corr[grp * n_cols + col] = - finalize_tie_corr(n_ref + n_grp, tie_sum); -} +// WARP (≤32) and SMALL (33–64) tiers were removed -- MEDIUM is now the smallest +// tier and covers all groups ≤ OVO_MEDIUM_MAX. The removed kernels (warp/small +// rank + warp tie helpers) are archived with restore steps in +// .claude/wilcoxon-warp-small-tiers-removed.md. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 18442bb3..c81a9fc8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -415,10 +415,10 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); bufs[s].grp_cub_temp = run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; + // All tiers share the ref tie base now (LARGE/HUGE included), so + // allocate whenever correcting, not only for the small-group tiers. bufs[s].ref_tie_sums = - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) - ? pool.alloc(sub_batch_cols) - : nullptr; + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 399f840c..51c35b26 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -95,16 +95,11 @@ static inline void scatter_cols_2d(double* dst, const double* src, int rows, sb_cols * sizeof(double), sb_cols * sizeof(double), rows, cudaMemcpyDeviceToDevice, stream); } -// WARP band: warp-per-(col,group) fused kernel. Each warp sorts+ranks one -// pair entirely in registers (warp-shuffle bitonic, no smem, no __syncthreads). -// Blocks pack 8 warps to amortise launch overhead. Fast route for -// perturbation-style workloads where most groups have a few dozen cells. -constexpr int OVO_WARP_MAX = 32; -// SMALL band: groups slightly larger than one warp. One compact smem sort -// block per (col, group), avoiding the heavier MEDIUM-band in-group scan. -constexpr int OVO_SMALL_MAX = 64; -// MEDIUM band: unsorted direct-rank kernel. Avoiding a full smem bitonic sort -// wins here despite the O(n^2) in-group count. +// MEDIUM band: unsorted direct-rank kernel and the SMALLEST OVO tier. Handles +// every group up to this size (the former WARP/SMALL sub-tiers were removed -- +// they added no measurable speedup on real tier-spanning data; see +// .claude/wilcoxon-warp-small-tiers-removed.md). Avoids a smem bitonic sort +// via an O(n^2) in-group count, cheap at these sizes. constexpr int OVO_MEDIUM_MAX = 512; // Max group size for the fused smem-sort rank kernel (the LARGE band). // Beyond this, fall back to the HUGE band: CUB segmented sort + rank kernel. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 7843579b..60043010 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -92,9 +92,7 @@ static void ovo_streaming_csr_impl( (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) : 0) + (run_huge ? cub_temp_bytes : 0) + - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium) - ? (size_t)sub_batch_cols * sizeof(double) - : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); size_t budget = rmm_available_device_bytes(0.8); size_t ref_reserve = @@ -130,10 +128,10 @@ static void ovo_streaming_csr_impl( bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].cub_temp = run_huge ? pool.alloc(cub_temp_bytes) : nullptr; + // LARGE/HUGE now share the ref tie base too: allocate whenever + // correcting. bufs[s].ref_tie_sums = - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) - ? pool.alloc(sub_batch_cols) - : nullptr; + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = @@ -317,9 +315,7 @@ static void ovo_streaming_csc_impl( (size_t)(sub_batch_cols + 1) * sizeof(int) + cub_temp_bytes + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) : 0) + - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium) - ? (size_t)sub_batch_cols * sizeof(double) - : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); size_t budget = rmm_available_device_bytes(0.8); n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); @@ -357,10 +353,10 @@ static void ovo_streaming_csc_impl( bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + // LARGE/HUGE now share the ref tie base too: allocate whenever + // correcting. bufs[s].ref_tie_sums = - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) - ? pool.alloc(sub_batch_cols) - : nullptr; + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index f0c120b5..2b2208ef 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -156,10 +156,10 @@ static void ovo_streaming_csc_host_impl( bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + // LARGE/HUGE now share the ref tie base too: allocate whenever + // correcting. bufs[s].ref_tie_sums = - (compute_tie_corr && (t1.run_warp || t1.run_small || t1.run_medium)) - ? pool.alloc(sub_batch_cols) - : nullptr; + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_tie_corr = @@ -632,12 +632,11 @@ static void ovo_streaming_csr_host_impl( OvoTierPlan pack_t1 = make_ovo_tier_plan(h_grp_offsets + pack.first, K); int pack_tpb_rank = round_up_to_warp( std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); - bool pack_has_above_t2 = pack_t1.max_grp_size > OVO_MEDIUM_MAX; - int pack_huge_skip_le = - pack_has_above_t2 ? OVO_MEDIUM_MAX : OVO_WARP_MAX; + // HUGE skips groups MEDIUM already handled (≤ OVO_MEDIUM_MAX). + int pack_huge_skip_le = OVO_MEDIUM_MAX; std::vector h_sort_group_ids; int pack_n_sort_groups = K; - if (pack_t1.above_warp && !pack_t1.run_large) { + if (pack_t1.above_medium && !pack_t1.run_large) { h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, K, pack_huge_skip_le); pack_n_sort_groups = (int)h_sort_group_ids.size(); @@ -647,7 +646,7 @@ static void ovo_streaming_csr_host_impl( cudaStream_t stream = streams[s]; auto& buf = bufs[s]; - if (pack_t1.above_warp && !pack_t1.run_large) { + if (pack_t1.above_medium && !pack_t1.run_large) { cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), h_sort_group_ids.size() * sizeof(int), cudaMemcpyHostToDevice, stream); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index 9cfc81bd..6f9346e9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -33,13 +33,8 @@ __global__ void build_huge_seg_offsets_kernel( */ struct OvoTierPlan { int max_grp_size = 0; - int min_grp_size = 0; - bool run_warp = false; // any group fits in one warp (≤ OVO_WARP_MAX) - bool run_large = - false; // any group needs > WARP but fits the LARGE smem-sort band - bool above_warp = false; // at least one group exceeds OVO_WARP_MAX - bool run_small = false; // SMALL band: (OVO_WARP_MAX, OVO_SMALL_MAX] - bool run_medium = false; // MEDIUM band: (OVO_SMALL_MAX, OVO_MEDIUM_MAX] + bool run_medium = false; // MEDIUM band: any group ≤ OVO_MEDIUM_MAX + bool run_large = false; // LARGE band: (OVO_MEDIUM_MAX, OVO_LARGE_MAX] bool above_medium = false; // at least one group exceeds OVO_MEDIUM_MAX int large_padded = 0; int large_tpb = 0; @@ -49,35 +44,26 @@ struct OvoTierPlan { // Single source of truth for OVO tier dispatch (used by the dense path AND all // four sparse OVO impls, which extract ref+group rows to dense then call this). // Scans group sizes once; returns which size bands to co-launch (by max group): -// WARP (<=32): ovo_rank_warp_kernel (warp-shuffle sort, in registers) -// SMALL (<=64): ovo_rank_small_kernel (fixed 64-element smem sort) // MEDIUM (<=512): ovo_rank_medium_kernel (no sort; O(n^2) in-group count) // LARGE (<=2500): ovo_rank_large_kernel (fused smem bitonic sort) // HUGE (>2500): CUB segmented sort + ovo_rank_huge_kernel (presorted rank) -// Bands cooperate via skip_n_grp_le (a larger band skips groups a smaller one -// already handled). LARGE is device-adapted: if its smem would exceed the -// per-block limit it falls back to HUGE. +// MEDIUM is the smallest tier (the WARP/SMALL sub-tiers were removed -- no +// measurable speedup on real data; archived in +// .claude/wilcoxon-warp-small-tiers-removed.md). MEDIUM co-launches with LARGE +// or HUGE; the upper tier skips groups ≤ OVO_MEDIUM_MAX (skip_n_grp_le). LARGE +// is device-adapted: if its smem would exceed the per-block limit it falls back +// to HUGE. static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { OvoTierPlan c; - c.min_grp_size = INT_MAX; for (int g = 0; g < n_groups; g++) { int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; if (sz > c.max_grp_size) c.max_grp_size = sz; - if (sz < c.min_grp_size) c.min_grp_size = sz; - if (sz > OVO_WARP_MAX && sz <= OVO_SMALL_MAX) { - c.run_small = true; - } - if (sz > OVO_SMALL_MAX && sz <= OVO_MEDIUM_MAX) { - c.run_medium = true; - } + if (sz <= OVO_MEDIUM_MAX) c.run_medium = true; if (sz > OVO_MEDIUM_MAX) c.above_medium = true; } - if (n_groups == 0) c.min_grp_size = 0; - c.run_warp = (c.min_grp_size <= OVO_WARP_MAX); - c.above_warp = (c.max_grp_size > OVO_WARP_MAX); - // run_large: the fused smem-sort fast path (groups > WARP but ≤ LARGE). - c.run_large = c.above_warp && (c.max_grp_size <= OVO_LARGE_MAX); + // run_large: the fused smem-sort fast path for groups > MEDIUM but ≤ LARGE. + c.run_large = c.above_medium && (c.max_grp_size <= OVO_LARGE_MAX); if (c.run_large) { c.large_padded = 1; while (c.large_padded < c.max_grp_size) c.large_padded <<= 1; @@ -107,23 +93,6 @@ static std::vector make_sort_group_ids(const int* h_grp_offsets, return ids; } -// WARP kernel launcher: 8 warps × 32 threads per block, one (col, group) -// pair per warp. grid.y covers ceil(K/8) pair rows. -static inline void launch_ovo_warp(const float* ref_sorted, - const float* grp_dense, - const int* grp_offsets, - const double* ref_tie_sums, - double* rank_sums, double* tie_corr, - int n_ref, int n_all_grp, int sb_cols, int K, - bool compute_tie_corr, cudaStream_t stream) { - constexpr int WARPS_PER_BLOCK = 8; - dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); - ovo_rank_warp_kernel<<>>( - ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, - n_ref, n_all_grp, sb_cols, K, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_rank_warp_kernel); -} - static inline void launch_ref_tie_sums(const float* ref_sorted, double* ref_tie_sums, int n_ref, int sb_cols, cudaStream_t stream) { @@ -132,18 +101,6 @@ static inline void launch_ref_tie_sums(const float* ref_sorted, CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); } -static inline void launch_ovo_small( - const float* ref_sorted, const float* grp_dense, const int* grp_offsets, - const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, - int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, - cudaStream_t stream) { - dim3 grid(sb_cols, K); - ovo_rank_small_kernel<<>>( - ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, - n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le); - CUDA_CHECK_LAST_ERROR(ovo_rank_small_kernel); -} - static inline void launch_ovo_medium( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, @@ -176,50 +133,54 @@ struct OvoTierScratch { // SINGLE OVO ranking engine, shared by the dense path and all four sparse OVO // impls (host/device CSC/CSR). Given a sorted reference slice and a dense group // slice for one column sub-batch, runs the size-banded dispatch from `plan` -// (see make_ovo_tier_plan): co-launch WARP/SMALL/MEDIUM for small groups, then -// LARGE (fused smem sort) OR HUGE (CUB segmented sort) for the rest. Callers -// differ only in how they produce ref_sorted / grp_dense. +// (see make_ovo_tier_plan): co-launch MEDIUM for groups ≤512, then LARGE (fused +// smem sort) OR HUGE (CUB segmented sort) for the rest. Callers differ only in +// how they produce ref_sorted / grp_dense. static inline void ovo_dispatch_tiers( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const OvoTierPlan& plan, const OvoTierScratch& sc, const int* d_sort_group_ids, int n_sort_groups, size_t grp_cub_temp_bytes, int sb_grp_items_actual, int tpb_rank, int n_ref, int n_all_grp, int sb_cols, int n_groups, bool compute_tie_corr, cudaStream_t stream) { + // No-tie fast path (tie_correct=False, the default): rank each group value + // vs the sorted reference only (U-identity), skipping the group sort and + // all tiers. grp_dense is unsorted here, which is exactly what this kernel + // wants. + if (!compute_tie_corr) { + constexpr int VS_REF_BLOCK = 256; + dim3 grid(sb_cols, n_groups); + ovo_rank_dense_vs_ref_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, sc.sub_rank_sums, n_ref, + n_all_grp, sb_cols, n_groups); + CUDA_CHECK_LAST_ERROR(ovo_rank_dense_vs_ref_kernel); + return; + } bool run_large = plan.above_medium && plan.run_large; bool run_huge = plan.above_medium && !run_large; - int skip_le = 0; - if (compute_tie_corr && - (plan.run_warp || plan.run_small || plan.run_medium)) { + // All tiers (MEDIUM/LARGE/HUGE) share the precomputed reference tie base, + // so compute it once per column whenever correcting. + if (compute_tie_corr) { launch_ref_tie_sums(ref_sorted, sc.ref_tie_sums, n_ref, sb_cols, stream); } - if (plan.run_warp) { - launch_ovo_warp(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, - sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, stream); - if (plan.above_warp) skip_le = OVO_WARP_MAX; - } - if (plan.run_small) { - launch_ovo_small(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, - sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, skip_le, stream); - if (plan.max_grp_size > OVO_SMALL_MAX) skip_le = OVO_SMALL_MAX; - } + // MEDIUM is the smallest tier: it handles every group ≤ OVO_MEDIUM_MAX + // (skip_n_grp_le = 0). LARGE/HUGE then take the groups above MEDIUM. if (plan.run_medium) { launch_ovo_medium(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr, skip_le, stream); + sb_cols, n_groups, compute_tie_corr, /*skip=*/0, + stream); } - int upper_skip_le = plan.above_medium ? OVO_MEDIUM_MAX : skip_le; + int upper_skip_le = plan.above_medium ? OVO_MEDIUM_MAX : 0; if (plan.above_medium && run_large) { dim3 grid(sb_cols, n_groups); ovo_rank_large_kernel<<>>( - ref_sorted, grp_dense, grp_offsets, sc.sub_rank_sums, - sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, plan.large_padded, upper_skip_le); + ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, plan.large_padded, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_rank_large_kernel); } else if (run_huge) { int sb_grp_seg = @@ -238,9 +199,9 @@ static inline void ovo_dispatch_tiers( dim3 grid(sb_cols, n_groups); ovo_rank_huge_kernel<<>>( - ref_sorted, sc.grp_sorted, grp_offsets, sc.sub_rank_sums, - sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, upper_skip_le); + ref_sorted, sc.grp_sorted, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_rank_huge_kernel); } } diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index baa2c06b..a79917fd 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -720,15 +720,15 @@ def _make_sized_groups_adata(group_sizes, n_genes, seed=0): return adata -# Tier thresholds (wilcoxon_fast_common.cuh): tier0<=32, tier0_64<=64, -# tier2<=512, tier1(fused smem sort)<=2500, tier3(CUB segmented sort)>2500. -# Group sizes in the standard blobs datasets are <=~70, so tier1/tier3 are -# otherwise never exercised. These force a single large test group. +# OVO tiers (wilcoxon_fast_common.cuh): MEDIUM<=512, LARGE(fused smem sort)<=2500, +# HUGE(CUB segmented sort)>2500. Group sizes in the standard blobs datasets are +# <=~70 (all MEDIUM), so LARGE/HUGE are otherwise never exercised. These force a +# single large test group. @pytest.mark.parametrize( "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] ) @pytest.mark.parametrize("tie_correct", [False, True]) -@pytest.mark.parametrize("big", [700, 3000], ids=["tier1_fused", "tier3_cub"]) +@pytest.mark.parametrize("big", [700, 3000], ids=["large_fused", "huge_cub"]) def test_wilcoxon_ovo_large_group_tiers_match_scanpy(fmt, tie_correct, big): # g0 = reference, g1 = the large test group that drives tier selection. adata_gpu = _make_sized_groups_adata([60, big, 45], n_genes=6, seed=1) @@ -759,40 +759,6 @@ def test_wilcoxon_ovo_large_group_tiers_match_scanpy(fmt, tie_correct, big): ) -@pytest.mark.parametrize( - "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] -) -def test_wilcoxon_ovo_mixed_tier_sizes_match_scanpy(fmt): - # Groups spanning tier0 (20), tier0_64 (50) and tier2 (300) co-launched with - # tie_correct=True, pinning the skip_le boundaries and the ref_tie_sums gate. - adata_gpu = _make_sized_groups_adata([80, 20, 50, 300], n_genes=6, seed=2) - adata_cpu = adata_gpu.copy() - adata_gpu.X = _to_format(adata_gpu.X, fmt) - - kw = { - "groupby": "group", - "method": "wilcoxon", - "use_raw": False, - "reference": "g0", - "tie_correct": True, - "n_genes": 6, - } - rsc.tl.rank_genes_groups(adata_gpu, **kw) - sc.tl.rank_genes_groups(adata_cpu, **kw) - - gpu = adata_gpu.uns["rank_genes_groups"] - cpu = adata_cpu.uns["rank_genes_groups"] - for field in ("scores", "pvals"): - for group in gpu[field].dtype.names: - np.testing.assert_allclose( - np.asarray(gpu[field][group], dtype=float), - np.asarray(cpu[field][group], dtype=float), - rtol=1e-13, - atol=1e-15, - equal_nan=True, - ) - - # n_groups > ~3056 makes the per-block smem for the sparse-OVR accumulator # ((2*n_groups+32) doubles) exceed the 48KB static limit, so sparse_ovr_smem_config # (and the dense ovr_smem_config) fall back to the global-memory accumulator. @@ -1792,9 +1758,9 @@ def _anndata_with_group_sizes(sizes, *, n_genes=6, seed=0): """Dense AnnData whose per-group cell counts are exactly ``sizes``. The OVO tier dispatch picks the rank kernel by *test-group* size - (WARP<=32, SMALL 33-64, MEDIUM 65-512, LARGE 513-2500, HUGE>2500), so - engineered group sizes drive specific bands. Integer data is float32-exact, - so ranking is bit-identical to scanpy. + (MEDIUM<=512, LARGE 513-2500, HUGE>2500), so engineered group sizes drive + specific bands. Integer data is float32-exact, so ranking is bit-identical + to scanpy. """ rng = np.random.default_rng(seed) labels = [] @@ -1827,11 +1793,11 @@ def _assert_ovo_matches_scanpy(adata, reference): ) -def test_ovo_tier_bands_warp_small_medium_large_match_scanpy(): - """OVO dense-tiered path must hit the WARP/SMALL/MEDIUM/LARGE rank kernels - (test-group sizes 20/50/300/1000, all <= 2500) and match scanpy.""" +def test_ovo_tier_bands_medium_large_match_scanpy(): + """OVO dense-tiered path: small groups (20/50/300, all <= 512) run through + MEDIUM and a 1000-cell group through LARGE; must match scanpy.""" adata = _anndata_with_group_sizes( - {"ref": 40, "warp": 20, "small": 50, "medium": 300, "large": 1000}, seed=1 + {"ref": 40, "g20": 20, "g50": 50, "g300": 300, "g1000": 1000}, seed=1 ) _assert_ovo_matches_scanpy(adata, reference="ref") From a4150ccbe03ff3305dd72bad8d9e78c248026a90 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 24 Jun 2026 17:39:47 +0200 Subject: [PATCH 29/36] fix logreg order Signed-off-by: Intron7 --- .../tools/_rank_genes_groups/_logreg.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py index d4bf0dc3..1232fe28 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import cupy as cp +import numpy as np from rapids_singlecell._compat import DaskArray, _meta_dense @@ -21,7 +22,20 @@ def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]: n_groups = len(rg.groups_order) selected = rg.group_codes < n_groups X = rg.X[selected, :] - grouping_logreg = rg.group_codes[selected].astype(X.dtype) + codes = rg.group_codes[selected] + + # Encode the multinomial class labels in canonical (original category) order + # rather than in `groups_order` order. groups_order echoes the user's + # `groups=` argument (see _select_groups), but cuML's softmax solver is not + # invariant to a class-index permutation, so without this the fitted scores + # would depend on the order groups are listed in. canon_label[i] is the + # class index used for groups_order[i]; coef_ rows are mapped back below. + cat_order = {str(c): i for i, c in enumerate(rg.labels.cat.categories)} + canon_key = np.array([cat_order[str(g)] for g in rg.groups_order]) + canon_label = np.empty(n_groups, dtype=np.int64) + canon_label[np.argsort(canon_key, kind="stable")] = np.arange(n_groups) + relabel = cp.asarray(canon_label) if isinstance(codes, cp.ndarray) else canon_label + grouping_logreg = relabel[codes].astype(X.dtype) if isinstance(X, DaskArray): import dask.array as da @@ -46,7 +60,8 @@ def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]: if n_groups <= 2: scores = scores_all[0].get() else: - scores = scores_all[igroup].get() + # coef_ rows are in canonical class order; map back to groups_order. + scores = scores_all[int(canon_label[igroup])].get() results.append((igroup, scores, None)) From dec359345ad69823595716f271dc8db473a18b5c Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 24 Jun 2026 19:16:18 +0200 Subject: [PATCH 30/36] add 64 bit Signed-off-by: Intron7 --- .../_cuda/sparse_extract/sparse_extract.cuh | 14 +++- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 8 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 79 +++++++++++++------ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 46 ++++++++--- .../tools/_rank_genes_groups/_wilcoxon.py | 41 +++++++--- 5 files changed, 135 insertions(+), 53 deletions(-) diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh index 2e4b02c8..36962b02 100644 --- a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -115,9 +115,9 @@ __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, // CSR selected rows -> dense F-order. row_ids[tid] = source row; output column // is (col - col_start), output row is tid. Requires sorted indices (binary // search + break). Output must be pre-zeroed. -template +template __global__ void csr_extract_dense_kernel(const T* __restrict__ data, - const int* __restrict__ indices, + const IndexT* __restrict__ indices, const IndptrT* __restrict__ indptr, const int* __restrict__ row_ids, T* __restrict__ out, int n_target, @@ -191,3 +191,13 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } } } + +// Narrowing element-wise cast (e.g. int64 row indices -> int32 sort values). +// Used only when the input index width exceeds int32; the caller guarantees the +// values fit the destination type (row/col positions < 2^31). +template +__global__ void cast_array_kernel(const SrcT* __restrict__ src, + DstT* __restrict__ dst, size_t n) { + size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) dst[i] = (DstT)src[i]; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 60043010..c0613f6b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -8,9 +8,9 @@ * reference slice. This mirrors the fast host-CSR path and avoids redoing the * reference dense extraction + segmented sort for every column sub-batch. */ -template +template static void ovo_streaming_csr_impl( - const float* csr_data, const int* csr_indices, const IndptrT* csr_indptr, + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { @@ -248,9 +248,9 @@ static void ovo_streaming_csr_impl( * Like the CSR variant, but extracts rows via lookup maps so it can operate on * native CSC input without converting the whole matrix. */ -template +template static void ovo_streaming_csc_impl( - const float* csc_data, const int* csc_indices, const IndptrT* csc_indptr, + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index f66e4db8..774c4953 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -44,8 +44,8 @@ static void ovr_sparse_csc_host_streaming_impl( if (max_nnz > 0) { int max_nnz_i32 = checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); - cub_temp_bytes = cub_segmented_sortpairs_temp_bytes( - max_nnz_i32, sub_batch_cols); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); } // pool first: streams drain before it frees their scratch (see guard doc). @@ -64,9 +64,10 @@ static void ovr_sparse_csc_host_streaming_impl( InT* d_sparse_data_orig; float* d_sparse_data_f32; IndexT* d_sparse_indices; + int* idx_i32; // int32 sort-val scratch; only used when IndexT != int int* d_seg_offsets; float* keys_out; - IndexT* vals_out; + int* vals_out; uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; @@ -79,9 +80,11 @@ static void ovr_sparse_csc_host_streaming_impl( bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].idx_i32 = + (sizeof(IndexT) > sizeof(int)) ? pool.alloc(max_nnz) : nullptr; bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].keys_out = pool.alloc(max_nnz); - bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); @@ -148,27 +151,42 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyHostToDevice, stream); } + // Row indices are the sort values; downcast int64 -> int32 at the + // device boundary (values < n_rows < 2^31) so sort + rank stay int32. + int* idx32; + if constexpr (sizeof(IndexT) > sizeof(int)) { + if (batch_nnz > 0) { + int cblk = (batch_nnz + tpb - 1) / tpb; + cast_array_kernel<<>>( + buf.d_sparse_indices, buf.idx_i32, (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + } + idx32 = buf.idx_i32; + } else { + idx32 = buf.d_sparse_indices; + } + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 - launch_ovr_cast_and_accumulate_sparse( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, idx32, buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_nnz, sb_cols, n_groups, compute_nnz, tpb, smem_cast, cast_use_gmem, stream); // Sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { - cub_segmented_sortpairs( - buf.cub_temp, cub_temp_bytes, buf.d_sparse_data_f32, - buf.keys_out, buf.d_sparse_indices, buf.vals_out, batch_nnz, - sb_cols, buf.d_seg_offsets, buf.d_seg_offsets + 1, stream, - "host CSC OVR segmented sort"); + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, + buf.d_sparse_data_f32, buf.keys_out, idx32, + buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, + stream, "host CSC OVR segmented sort"); } - launch_ovr_sparse_rank( + launch_ovr_sparse_rank( buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, @@ -732,9 +750,9 @@ static void ovr_sparse_csr_host_streaming_impl( // Sparse-aware CSC OVR streaming (sort only stored nonzeros) // ============================================================================ -template +template static void ovr_sparse_csc_streaming_impl( - const float* csc_data, const int* csc_indices, const IndptrT* csc_indptr, + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, const int* group_codes, const double* group_sizes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { @@ -787,6 +805,7 @@ static void ovr_sparse_csc_streaming_impl( struct StreamBuf { float* keys_out; int* vals_out; + int* idx_i32; // int32 sort-val scratch; only used when IndexT != int int* seg_offsets; uint8_t* cub_temp; double* sub_rank_sums; @@ -797,6 +816,8 @@ static void ovr_sparse_csc_streaming_impl( for (int s = 0; s < n_streams; s++) { bufs[s].keys_out = pool.alloc(max_nnz); bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].idx_i32 = + (sizeof(IndexT) > sizeof(int)) ? pool.alloc(max_nnz) : nullptr; bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].sub_rank_sums = @@ -833,13 +854,27 @@ static void ovr_sparse_csc_streaming_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Sort only stored values (keys=data, vals=row_indices) + // Sort only stored values (keys=data, vals=row_indices). Row indices + // always fit int32 (n_rows < 2^31); downcast int64 input here so the + // sort + rank stay int32 (half the val buffer) -- the device boundary. if (batch_nnz > 0) { - cub_segmented_sortpairs( - buf.cub_temp, cub_temp_bytes, csc_data + ptr_start, - buf.keys_out, csc_indices + ptr_start, buf.vals_out, batch_nnz, - sb_cols, buf.seg_offsets, buf.seg_offsets + 1, stream, - "device CSC OVR segmented sort"); + const int* idx_src; + if constexpr (sizeof(IndexT) > sizeof(int)) { + int cblk = (batch_nnz + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + cast_array_kernel + <<>>( + csc_indices + ptr_start, buf.idx_i32, + (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + idx_src = buf.idx_i32; + } else { + idx_src = csc_indices + ptr_start; + } + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, + csc_data + ptr_start, buf.keys_out, idx_src, + buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, + stream, "device CSC OVR segmented sort"); } // Sparse rank kernel (handles implicit zeros analytically) @@ -879,9 +914,9 @@ static void ovr_sparse_csc_streaming_impl( * * Compared to the dense CSR path, sort work drops by ~1/sparsity. */ -template +template static void ovr_sparse_csr_streaming_impl( - const float* csr_data, const int* csr_indices, const IndptrT* csr_indptr, + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, const int* group_codes, const double* group_sizes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index fdb2510d..928b8af1 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -19,11 +19,11 @@ template void register_sparse_bindings(nb::module_& m) { m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; -#define RSC_OVR_SPARSE_DEVICE_BINDING(NAME, IMPL, IndptrCType) \ +#define RSC_OVR_SPARSE_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ m.def( \ NAME, \ [](gpu_array_c data, \ - gpu_array_c indices, \ + gpu_array_c indices, \ gpu_array_c indptr, \ gpu_array_c group_codes, \ gpu_array_c group_sizes, \ @@ -42,13 +42,19 @@ void register_sparse_bindings(nb::module_& m) { "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", - ovr_sparse_csc_streaming_impl, int); + ovr_sparse_csc_streaming_impl, int, int); RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device_i64", - ovr_sparse_csc_streaming_impl, int64_t); + ovr_sparse_csc_streaming_impl, int, int64_t); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device_i64_idx64", + ovr_sparse_csc_streaming_impl, int64_t, + int64_t); RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", - ovr_sparse_csr_streaming_impl, int); + ovr_sparse_csr_streaming_impl, int, int); RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device_i64", - ovr_sparse_csr_streaming_impl, int64_t); + ovr_sparse_csr_streaming_impl, int, int64_t); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device_i64_idx64", + ovr_sparse_csr_streaming_impl, int64_t, + int64_t); #undef RSC_OVR_SPARSE_DEVICE_BINDING #define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -85,6 +91,12 @@ void register_sparse_bindings(nb::module_& m) { int); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, int64_t); + // int64 row indices (int64 indptr): pass indices natively, downcast to + // int32 per-batch on-device rather than a full host int32 copy. + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64_idx64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64_idx64", double, + int64_t, int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING #define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -129,11 +141,11 @@ void register_sparse_bindings(nb::module_& m) { int64_t, int64_t); #undef RSC_OVR_SPARSE_CSR_HOST_BINDING -#define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndptrCType) \ +#define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ m.def( \ NAME, \ [](gpu_array_c data, \ - gpu_array_c indices, \ + gpu_array_c indices, \ gpu_array_c indptr, \ gpu_array_c ref_rows, \ gpu_array_c grp_rows, \ @@ -154,13 +166,17 @@ void register_sparse_bindings(nb::module_& m) { "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, - int); + int, int); RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device_i64", - ovo_streaming_csc_impl, int64_t); + ovo_streaming_csc_impl, int, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device_i64_idx64", + ovo_streaming_csc_impl, int64_t, int64_t); RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, - int); + int, int); RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device_i64", - ovo_streaming_csr_impl, int64_t); + ovo_streaming_csr_impl, int, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device_i64_idx64", + ovo_streaming_csr_impl, int64_t, int64_t); #undef RSC_OVO_DEVICE_BINDING #define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -200,6 +216,12 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, int64_t); + // int64 row indices: read natively (extraction only, never sorted) to skip + // the full host int32 copy. + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64_idx64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64_idx64", double, + int64_t, int64_t); #undef RSC_OVO_CSC_HOST_BINDING #define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 892d48db..4a5792c2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -369,19 +369,34 @@ def _device_sparse_arrays_f32(X): raise TypeError(msg) data = X.data.astype(cp.float32, copy=False) - # Row/column indices fit int32 (cells and genes are < 2^31); indptr - # (cumulative nnz) may need int64, which the *_i64 device kernels handle. - indices = X.indices.astype(cp.int32, copy=False) - if X.indptr.dtype == cp.int64: - indptr = X.indptr + # Pass int64 indices natively to the *_idx64 kernels rather than a full nnz + # int32 copy (indices are only int64 when nnz > 2^31). int64 indices imply + # an int64 indptr, which those kernels require -- promote the (tiny) indptr + # if a hand-built matrix left it int32. Index values always fit int32; the + # CSC kernels downcast per-batch on-device where it's the sort value. + if X.indices.dtype == cp.int64: + indices = X.indices + indptr = ( + X.indptr + if X.indptr.dtype == cp.int64 + else X.indptr.astype(cp.int64, copy=False) + ) else: - indptr = X.indptr.astype(cp.int32, copy=False) + indices = X.indices.astype(cp.int32, copy=False) + indptr = ( + X.indptr + if X.indptr.dtype == cp.int64 + else X.indptr.astype(cp.int32, copy=False) + ) return data, indices, indptr -def _device_sparse_fn(module, base_name: str, indptr: cp.ndarray): - """Select the device kernel binding, using the int64-indptr variant if needed.""" - suffix = "_i64" if indptr.dtype == cp.int64 else "" +def _device_sparse_fn(module, base_name: str, indptr: cp.ndarray, indices: cp.ndarray): + """Select the device kernel binding (int64 indptr / int64 indices variants).""" + if indptr.dtype == cp.int64: + suffix = "_i64_idx64" if indices.dtype == cp.int64 else "_i64" + else: + suffix = "" return getattr(module, base_name + suffix) @@ -636,7 +651,7 @@ def _wilcoxon_vs_rest( rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) tie_corr = cp.ones(n_total_genes, dtype=cp.float64) if cpsp.isspmatrix_csc(X): - _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr)( + _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr, indices)( data, indices, indptr, @@ -656,7 +671,7 @@ def _wilcoxon_vs_rest( sparse_X = sparse_X.copy() sparse_X.sort_indices() data, indices, indptr = _device_sparse_arrays_f32(sparse_X) - _device_sparse_fn(_wcs, "ovr_sparse_csr_device", indptr)( + _device_sparse_fn(_wcs, "ovr_sparse_csr_device", indptr, indices)( data, indices, indptr, @@ -1049,7 +1064,7 @@ def _wilcoxon_with_reference( ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) - _device_sparse_fn(_wcs, "ovo_streaming_csc_device", indptr)( + _device_sparse_fn(_wcs, "ovo_streaming_csc_device", indptr, indices)( data, indices, indptr, @@ -1066,7 +1081,7 @@ def _wilcoxon_with_reference( sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, ) else: - _device_sparse_fn(_wcs, "ovo_streaming_csr_device", indptr)( + _device_sparse_fn(_wcs, "ovo_streaming_csr_device", indptr, indices)( data, indices, indptr, From a2c4b3a2d2c7287012bb050cbd7bb35c989f52d3 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 25 Jun 2026 12:13:23 +0200 Subject: [PATCH 31/36] update streaming Signed-off-by: Intron7 --- .../_cuda/sparse_extract/sparse_extract.cuh | 44 ++- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 221 ++++---------- .../_cuda/wilcoxon/wilcoxon.cu | 40 +-- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 287 ++++++++++-------- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 44 ++- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 285 +++++++++-------- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 88 +++--- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 55 ++-- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 190 +++++------- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 16 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 113 +++---- .../tools/_rank_genes_groups/_wilcoxon.py | 121 +++++--- tests/test_rank_genes_groups_wilcoxon.py | 163 ++-------- 13 files changed, 740 insertions(+), 927 deletions(-) diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh index 36962b02..2496d16c 100644 --- a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -3,27 +3,22 @@ #include // ============================================================================ -// Shared CSR/CSC -> {compact CSC, dense} extraction kernels. -// -// Header-only templates used by the wilcoxon and rank_genes CUDA modules to -// land a gene-column window on the GPU in a column-usable layout. Two families: +// Shared CSR/CSC -> {compact CSC, dense} extraction kernels (header-only). // * compact CSC (csr_scatter_to_csc) -> sparse ranker (nnz only) // * dense F-order (csr_tile_to_dense, extract) -> dense ranker (all values) // ============================================================================ /** * Scatter CSR nonzeros into compact CSC for columns [col_start, col_stop). - * write_pos[c - col_start] is the prefix-sum offset for column c; each thread - * atomically claims a unique destination slot. + * write_pos[c - col_start] is column c's prefix-sum offset; threads atomically + * claim destination slots. * - * PRECONDITION: each row's `indices` must be sorted ascending -- the binary - * search for col_start and the `break` at col_stop depend on it; unsorted rows - * would silently drop or misplace nonzeros. Python dispatch calls - * `sort_indices()` before launching this kernel. + * PRECONDITION: each row's `indices` sorted ascending -- the binary search for + * col_start and the `break` at col_stop depend on it; unsorted rows would + * silently drop/misplace nonzeros. Python dispatch calls sort_indices() first. * - * `row_offset` is added to the local row index so a row-block rebased to a - * local [0, n_rows) range still records the correct global row id (out-of-core - * row-streaming OVR path). Defaults to 0 for full-matrix callers. + * `row_offset` rebases a local-row block to its global row id (out-of-core + * row-streaming OVR path). 0 for full-matrix callers. */ template __global__ void csr_scatter_to_csc_kernel( @@ -54,12 +49,11 @@ __global__ void csr_scatter_to_csc_kernel( } // Single-pass CSR-slice + densify: scatter column window [col_lb, col_ub) into -// a dense (n_cells, col_ub-col_lb) F-order double buffer, skipping the CSR -> -// CSC rebuild a `X[:, lb:ub].tocsc()` densify would do. +// a dense (n_cells, col_ub-col_lb) F-order double buffer. // -// `out` must be pre-zeroed; the atomicAdd also sums duplicate column indices -// (like scipy's sum_duplicates) -- bit-identical to dense materialization for -// canonical CSR. Output is always double; input dtype is templated. +// `out` must be pre-zeroed; atomicAdd sums duplicate column indices (like +// scipy's sum_duplicates) -- bit-identical to dense materialization for +// canonical CSR. Output always double; input dtype templated. template __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndexT* __restrict__ indices, @@ -86,15 +80,13 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, } } -// CSC column-window [col_lb, col_ub) -> dense F-order (double), single fused -// pass. One block per column; threads stride that column's nonzeros. Writes are -// column-major coalesced and need NO atomicAdd -- canonical CSC has a unique -// (col,row) per nonzero (the wilcoxon dispatch canonicalizes/sums first). This -// is the densify-from-CSC counterpart to csr_tile_to_dense_kernel. +// CSC column-window [col_lb, col_ub) -> dense F-order (double), one block per +// column. NO atomicAdd -- canonical CSC has a unique (col,row) per nonzero (the +// wilcoxon dispatch canonicalizes/sums first). CSC counterpart to +// csr_tile_to_dense_kernel. // -// `out` must be pre-zeroed. `indptr` indexes columns; pass either full-matrix -// column pointers (with col_lb/col_ub) or a window rebased to [0, -// col_ub-col_lb). +// `out` must be pre-zeroed. `indptr` indexes columns; pass full-matrix column +// pointers (with col_lb/col_ub) or a window rebased to [0, col_ub-col_lb). template __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndexT* __restrict__ indices, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 8bbf4bd2..ff6a2f1a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -5,11 +5,9 @@ #include "wilcoxon_block_reduce.cuh" #include "wilcoxon_fast_common.cuh" -// ============================================================================ -// Bitonic sort of `n` floats in shared memory, ascending. `n` must be a power -// of two; pad the tail with +INF before calling. Grid-stride, so any blockDim -// works (covers both the LARGE runtime-sized and SMALL fixed-size paths). -// ============================================================================ +// Bitonic sort of `n` floats in shared memory, ascending. `n` MUST be a power +// of two; pad the tail with +INF before calling. Grid-stride: any blockDim +// works. __device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { for (int k = 2; k <= n; k <<= 1) { @@ -30,12 +28,9 @@ __device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { } } -// ============================================================================ -// Sorted-array bounds over [lo, hi). lower: first index with arr[idx] >= v -// (count of elements < v). upper: first index with arr[idx] > v (count <= v). -// Pass an advanced `lo` to exploit per-thread-stride monotonicity. Work for -// both global and shared `arr`. -// ============================================================================ +// Sorted-array bounds over [lo, hi). lower: first idx with arr[idx] >= v (count +// of elements < v). upper: first idx with arr[idx] > v (count <= v). Advanced +// `lo` exploits per-thread-stride monotonicity; works on global or shared arr. __device__ __forceinline__ int sorted_lower_bound(const float* arr, int lo, int hi, float v) { @@ -62,8 +57,8 @@ __device__ __forceinline__ int sorted_upper_bound(const float* arr, int lo, } // Mid-rank of `v` in the merged (ref, grp) arrays. Advances the four -// incremental bounds (pass 0,0,0,0 for a fresh per-element search) and reports -// the per-array equal counts for tie correction. +// incremental bounds (pass 0,0,0,0 for a fresh search); reports per-array equal +// counts for tie correction. struct OvoRank { double mid_rank; int n_eq_ref; @@ -95,19 +90,12 @@ __device__ __forceinline__ OvoRank ovo_mid_rank(const float* ref, int n_ref, return r; } -// ============================================================================ -// Amortized tie correction for the LARGE/HUGE bands (group is SORTED). -// -// Adds only the group-only / ref-overlap delta on top of the precomputed -// reference base ref_tie_sums[col] (= ref_tie_sum_kernel), exactly like the -// MEDIUM band. Iterates the sorted group's UNIQUE values only -- one binary -// search into the ref per unique value -- so the reference is NOT rescanned per -// group (the cost a naive full-combined-rescan would pay). This makes tie -// correction O(n_grp_unique * log n_ref) instead of O(n_ref) per group, which -// dominates whenever the reference is large (e.g. a perturbation control or a -// big cluster). Bit-identical: same per-value (t^3 - t) terms, just -// reassociated against the shared ref base. -// ============================================================================ +// Amortized tie correction for LARGE/HUGE bands (group is SORTED). Adds only +// the group-only / ref-overlap delta on the precomputed ref base +// ref_tie_sums[col], like MEDIUM. Iterates the group's UNIQUE values only (one +// ref binary search each) so the ref is NOT rescanned per group: O(n_grp_unique +// * log n_ref) vs O(n_ref)/group. Bit-identical: same per-value (t^3 - t) +// terms, reassociated against the shared ref base. __device__ __forceinline__ void compute_tie_delta_sorted_grp( const float* ref_col, int n_ref, const float* grp_col, int n_grp, double ref_base, double* warp_buf, double* out) { @@ -136,16 +124,13 @@ __device__ __forceinline__ void compute_tie_delta_sorted_grp( *out = finalize_tie_corr(n_ref + n_grp, ref_base + tie); } -// ============================================================================ -// No-tie fast path (tie_correct=False, the default). Ranks each group value -// against the sorted REFERENCE only, via the Mann-Whitney U identity: +// No-tie fast path (tie_correct=False, default). Ranks each group value against +// the sorted REFERENCE only, via the Mann-Whitney U identity: // R_g = n_grp(n_grp+1)/2 + Σ_{g values}(#ref_below + 0.5·#ref_equal) -// The group-internal ranks collapse to the closed form, so the group block -// needs NO sort — each value does a full binary search into sorted ref. This -// eliminates the group segmented sort, ~half of dense-OVO time (see profiling). -// rank_sums are exact half-integers, so this matches the tiered path -// bit-for-bit. Grid: (n_cols, n_groups), Block: (tpb,). grp_dense is UNSORTED. -// ============================================================================ +// Group-internal ranks collapse to the closed form, so the group needs NO sort +// (each value binary-searches the sorted ref) -- skips the group segmented +// sort, ~half of dense-OVO time. rank_sums are exact half-integers => matches +// the tiered path bit-for-bit. Grid (n_cols, n_groups). grp_dense is UNSORTED. __global__ void ovo_rank_dense_vs_ref_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, @@ -178,30 +163,27 @@ __global__ void ovo_rank_dense_vs_ref_kernel( } } -// ============================================================================ -// Batched rank sums — pre-sorted (binary search, no shared memory sort). -// Used by the OVO streaming pipeline in wilcoxon_streaming.cu. -// Each thread carries lower/upper bounds across iterations, exploiting -// sorted-grp_col monotonicity within its stride. -// ============================================================================ - -__global__ void ovo_rank_huge_kernel( - const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, +// LARGE/HUGE pre-sorted rank kernel. Grid (n_cols, n_groups); each thread +// carries lower/upper bounds across its stride (sorted-grp_col monotonicity). +// SMEM_SORT=true (LARGE, groups <= OVO_LARGE_MAX): load unsorted group into +// dynamic smem (large_padded floats) + bitonic-sort. =false (HUGE): read a +// CUB-segmented-sorted group from global. Post-sort body (incremental mid-ranks +// + amortized ref-tie delta) is shared. Each group owns its rank_sums/tie_corr +// row, so size-gated co-launch (skip_n_grp_le) never aliases. +template +__global__ void ovo_rank_sorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_in, const int* __restrict__ grp_offsets, const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { + int n_groups, bool compute_tie_corr, int large_padded, int skip_n_grp_le) { int col = blockIdx.x; int grp = blockIdx.y; if (col >= n_cols || grp >= n_groups) return; int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - // Size-gated dispatch (see ovo_rank_large_kernel for the contract). + int n_grp = grp_offsets[grp + 1] - g_start; if (n_grp <= skip_n_grp_le) return; - if (n_grp == 0) { if (threadIdx.x == 0) { rank_sums[grp * n_cols + col] = 0.0; @@ -211,119 +193,46 @@ __global__ void ovo_rank_huge_kernel( } const float* ref_col = ref_sorted + (long long)col * n_ref; - const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; - - // Incremental binary search bounds (advance monotonically per thread) - int ref_lb = 0, ref_ub = 0; - int grp_lb = 0, grp_ub = 0; - double local_sum = 0.0; - - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_col, n_grp, grp_col[i], - ref_lb, ref_ub, grp_lb, grp_ub); - local_sum += r.mid_rank; - } - __shared__ double warp_buf[32]; - double total = wilcoxon_block_sum(local_sum, warp_buf); - if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; - - if (!compute_tie_corr) return; - __syncthreads(); - - // grp_col is sorted (CUB segmented sort upstream): amortize the ref tie - // contribution via the precomputed base instead of rescanning the ref. - compute_tie_delta_sorted_grp(ref_col, n_ref, grp_col, n_grp, - ref_tie_sums[col], warp_buf, - &tie_corr[grp * n_cols + col]); -} - -// ============================================================================ -// LARGE-band fused kernel: smem bitonic sort + binary search rank sums -// For groups up to OVO_LARGE_MAX cells. No CUB, no global memory sort buffers. -// Grid: (n_cols, n_groups), Block: min(large_padded, 512) -// Shared memory: large_padded floats + 32 doubles (warp reduction) -// ============================================================================ - -__global__ void ovo_rank_large_kernel( - const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted - const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) - // unsorted - const int* __restrict__ grp_offsets, // (n_groups + 1,) - const double* __restrict__ ref_tie_sums, // (n_cols,) ref tie base - double* __restrict__ rank_sums, // (n_groups, n_cols) row-major - double* __restrict__ tie_corr, // (n_groups, n_cols) row-major - int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, - int large_padded, int skip_n_grp_le /*= 0*/) { - int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - // Size-gated dispatch: when co-launched with the WARP kernel we - // skip groups it's already handling. Each group owns its own - // rank_sums row, so the two kernels' writes never alias. - if (n_grp <= skip_n_grp_le) return; - - if (n_grp == 0) { - if (threadIdx.x == 0) { - rank_sums[grp * n_cols + col] = 0.0; - if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; - } - return; + const float* grp_col; + if constexpr (SMEM_SORT) { + extern __shared__ float grp_smem[]; + const float* src = grp_in + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = src[i]; + for (int i = n_grp + threadIdx.x; i < large_padded; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF pad + __syncthreads(); + bitonic_sort_smem(grp_smem, large_padded); + grp_col = grp_smem; + } else { + (void)large_padded; + grp_col = + grp_in + (long long)col * n_all_grp + g_start; // CUB-presorted } - // Shared memory: [large_padded floats | 32 doubles for warp reduction] - extern __shared__ char smem_raw[]; - float* grp_smem = (float*)smem_raw; - double* warp_buf = (double*)(smem_raw + large_padded * sizeof(float)); - - // Load group into smem, pad with +INF - const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) - grp_smem[i] = grp_col[i]; - for (int i = n_grp + threadIdx.x; i < large_padded; i += blockDim.x) - grp_smem[i] = __int_as_float(0x7f800000); // +INF - __syncthreads(); - - // Bitonic sort in shared memory - bitonic_sort_smem(grp_smem, large_padded); - - // Binary search each sorted grp element against sorted ref; - // incremental bounds (monotonic within each thread's stride) - const float* ref_col = ref_sorted + (long long)col * n_ref; - int ref_lb = 0, ref_ub = 0; - int grp_lb = 0, grp_ub = 0; + int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0; double local_sum = 0.0; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_smem, n_grp, grp_smem[i], + OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_col, n_grp, grp_col[i], ref_lb, ref_ub, grp_lb, grp_ub); local_sum += r.mid_rank; } - double total = wilcoxon_block_sum(local_sum, warp_buf); if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; __syncthreads(); - - // grp_smem is sorted here: amortize the ref tie contribution via the - // precomputed base instead of rescanning the ref per group. - compute_tie_delta_sorted_grp(ref_col, n_ref, grp_smem, n_grp, + // grp_col is sorted: amortize the ref tie contribution via the precomputed + // base instead of rescanning the ref per group. + compute_tie_delta_sorted_grp(ref_col, n_ref, grp_col, n_grp, ref_tie_sums[col], warp_buf, &tie_corr[grp * n_cols + col]); } -// ============================================================================ -// MEDIUM-band helper: tie contribution of the sorted reference alone. -// One block per column. The medium unsorted-rank kernel uses this as a base -// and only adds group-only/overlap deltas from the unsorted group values. -// ============================================================================ - +// MEDIUM-band helper: tie contribution of the sorted reference alone (one block +// per column). The rank kernels use this base and add only group-only/overlap +// deltas from the group values. __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, double* __restrict__ ref_tie_sums, int n_ref, int n_cols) { @@ -348,15 +257,10 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, if (threadIdx.x == 0) ref_tie_sums[col] = total; } -// ============================================================================ -// MEDIUM-band fused kernel: no-sort direct rank for medium groups. -// -// Avoids the smem bitonic sort for groups in (skip_n_grp_le, -// max_n_grp_le]. Ranks are computed from ref binary searches plus an -// in-group scan over unsorted shared values. Tie correction starts from -// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. -// ============================================================================ - +// MEDIUM-band fused kernel: no-sort direct rank for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks = ref binary searches + an in-group scan over unsorted +// shared values. Tie correction starts from ref_tie_sums[col] and adds only +// group-only / ref-overlap deltas. __global__ void ovo_rank_medium_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, @@ -433,7 +337,6 @@ __global__ void ovo_rank_medium_kernel( finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); } -// WARP (≤32) and SMALL (33–64) tiers were removed -- MEDIUM is now the smallest -// tier and covers all groups ≤ OVO_MEDIUM_MAX. The removed kernels (warp/small -// rank + warp tie helpers) are archived with restore steps in -// .claude/wilcoxon-warp-small-tiers-removed.md. +// WARP (≤32) and SMALL (33–64) tiers were removed; MEDIUM is now the smallest +// tier, covering all groups ≤ OVO_MEDIUM_MAX. Removed kernels archived with +// restore steps in .claude/wilcoxon-warp-small-tiers-removed.md. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index c81a9fc8..e7419c7b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -123,13 +123,11 @@ static void launch_ovr_rank_dense_streaming( sync_streams(streams, "dense OVR streaming rank"); } -// Host-streaming dense OVR: same multi-stream pipeline as the host-CSC path -// (pinned host, round-robin streams, per-batch async H2D overlapping the rank) -// feeding the dense sort+rank above. Both layouts read into an F-order device -// block PER SUB-BATCH (the full array is never transposed): F-order is a -// contiguous memcpy; C-order is a strided cudaMemcpy2DAsync of the sub-batch -// tile then read into F-order. Input dtype is cast to float32 keys; group sums -// (+ nnz) are accumulated in f64 from the native-dtype staging for means/pts. +// Host-streaming dense OVR: pinned-host multi-stream pipeline feeding the dense +// sort+rank above. Reads each sub-batch into an F-order device block (full +// array never transposed): F-order = contiguous memcpy, C-order = strided 2D +// copy. Keys cast to f32; group sums (+nnz) accumulated in f64 from native +// staging. template static void launch_ovr_rank_dense_host_streaming( const T* h_X, bool f_order, const int* group_codes, double* rank_sums, @@ -154,9 +152,8 @@ static void launch_ovr_rank_dense_host_streaming( size_t cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); - // Clamp the stream count to the device memory budget (like the sparse - // launchers) so a tall/wide host-dense matrix shrinks the pipeline rather - // than OOMing on unbounded per-stream sort scratch. + // Clamp stream count to device memory budget so a large matrix shrinks the + // pipeline rather than OOMing on per-stream sort scratch. size_t per_stream_bytes = sub_items * (sizeof(T) + (fast_keys ? 0 : sizeof(float)) + sizeof(float) + 2 * sizeof(int)) + @@ -171,8 +168,7 @@ static void launch_ovr_rank_dense_host_streaming( // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; - // Best-effort pin of the host array for faster async H2D; on failure (pin - // caps / non-pinnable memory) proceed unpinned rather than raising. + // Best-effort pin for faster async H2D; on failure proceed unpinned. HostRegisterGuard _pin(const_cast(h_X), (size_t)n_rows * n_cols * sizeof(T), 0, /*best_effort=*/true); @@ -228,7 +224,7 @@ static void launch_ovr_rank_dense_host_streaming( cudaStream_t stream = streams[s]; auto& buf = bufs[s]; - // H2D the column window on this stream (overlaps the prior batch rank). + // H2D the column window (overlaps the prior batch rank). if (f_order) { cudaMemcpyAsync(buf.d_stg, h_X + (size_t)col * n_rows, (size_t)sb_items * sizeof(T), @@ -244,9 +240,8 @@ static void launch_ovr_rank_dense_host_streaming( if (fast_keys) { keys_in = reinterpret_cast(buf.d_stg); } else { - // dense_block_to_f32_kernel is grid-stride, so a bounded grid - // covers any sb_items (up to INT_MAX) with no overflow in the - // launch math and enough blocks to saturate the device. + // grid-stride kernel: bounded grid covers any sb_items (<=INT_MAX) + // with no launch-math overflow. const unsigned int grid = (unsigned int)std::min( ((size_t)sb_items + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE, 65535u); @@ -266,9 +261,8 @@ static void launch_ovr_rank_dense_host_streaming( buf.vals_out, sb_items, sb_cols, buf.seg_offsets, buf.seg_offsets + 1, stream, "dense host OVR segmented sort"); - // gmem rank mode atomicAdds onto sub_rank_sums without self-zeroing, - // and the per-stream buffer is reused round-robin, so zero it first - // (the device-resident launcher does the same). + // gmem rank mode atomicAdds without self-zeroing and the buffer is + // reused round-robin, so zero it first. if (use_gmem) { cuda_check(cudaMemsetAsync( buf.sub_rank_sums, 0, @@ -294,9 +288,8 @@ static void launch_ovr_rank_dense_host_streaming( "dense host OVR tie_corr D2D copy"); } - // Group sums (+ nnz) for means/pts, in f64 from the native-dtype - // staging (matches the Aggregate path); fed to - // _fill_basic_stats_from_accumulators like the host-CSC path. + // Group sums (+nnz) for means/pts, f64 from native staging (matches + // the Aggregate path). if (compute_stats) { cudaMemsetAsync(buf.sub_group_sums, 0, (size_t)n_groups * sb_cols * sizeof(double), @@ -415,8 +408,7 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); bufs[s].grp_cub_temp = run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; - // All tiers share the ref tie base now (LARGE/HUGE included), so - // allocate whenever correcting, not only for the small-group tiers. + // All tiers share the ref tie base, so allocate whenever correcting. bufs[s].ref_tie_sums = compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 51c35b26..72b0e1f5 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -22,10 +23,9 @@ static inline int host_worker_count() { return (int)std::min(hw ? hw : 4u, 32u); } -// Run fn(chunk, r0, r1) over a contiguous partition of [0, n); `chunk` is the -// 0-based worker index (for per-thread scratch). fn runs concurrently, so it -// must only read shared state and write disjoint output ranges (keyed by chunk -// or by [r0,r1)). Returns the number of chunks used. Serial for small n. +// Run fn(chunk, r0, r1) over a partition of [0, n); `chunk` = 0-based worker +// index. fn runs concurrently: read-only shared state, disjoint output ranges +// (keyed by chunk or [r0,r1)). Returns chunks used; serial for small n. template static inline int host_parallel_chunks(int n, F fn) { if (n <= 0) return 0; @@ -48,9 +48,8 @@ static inline int host_parallel_chunks(int n, F fn) { return used; } -// Run fn(r0, r1) over a contiguous partition of [0, n) across hardware threads -// (serial for small n). fn is invoked concurrently, so it must only read shared -// state and write disjoint output ranges. Used for host-side CSR gathers. +// Run fn(r0, r1) over a partition of [0, n) across hardware threads (serial for +// small n). Concurrent: read-only shared state, disjoint output ranges. template static inline void host_parallel_ranges(int n, F fn) { host_parallel_chunks(n, [&fn](int, int r0, int r1) { fn(r0, r1); }); @@ -67,9 +66,8 @@ constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; -// Stream-count clamps shared by the streaming impls: never use more streams -// than there are column batches, nor more than the per-stream memory budget -// allows. +// Stream-count clamps: never use more streams than column batches, nor more +// than the per-stream memory budget allows. static inline int clamp_streams_by_cols(int n_cols, int sub_batch_cols) { int n = N_STREAMS; if (n_cols < n * sub_batch_cols) @@ -85,9 +83,9 @@ static inline int clamp_streams_by_budget(int n_streams, return n_streams; } -// Scatter a [rows, sb_cols] device sub-batch block (row-major doubles, source -// row stride sb_cols) into `dst` whose row stride is n_cols. `dst` must already -// point at the destination column offset (e.g. out + col). +// Scatter a [rows, sb_cols] device sub-batch (row-major doubles, src stride +// sb_cols) into `dst` (stride n_cols). `dst` must point at the dest column +// offset (e.g. out + col). static inline void scatter_cols_2d(double* dst, const double* src, int rows, int n_cols, int sb_cols, cudaStream_t stream) { @@ -95,22 +93,25 @@ static inline void scatter_cols_2d(double* dst, const double* src, int rows, sb_cols * sizeof(double), sb_cols * sizeof(double), rows, cudaMemcpyDeviceToDevice, stream); } -// MEDIUM band: unsorted direct-rank kernel and the SMALLEST OVO tier. Handles -// every group up to this size (the former WARP/SMALL sub-tiers were removed -- -// they added no measurable speedup on real tier-spanning data; see -// .claude/wilcoxon-warp-small-tiers-removed.md). Avoids a smem bitonic sort -// via an O(n^2) in-group count, cheap at these sizes. +// MEDIUM band cap: groups up to this size use unsorted O(n^2) in-group-count +// rank (no smem sort). Tier dispatch: make_ovo_tier_plan. constexpr int OVO_MEDIUM_MAX = 512; -// Max group size for the fused smem-sort rank kernel (the LARGE band). -// Beyond this, fall back to the HUGE band: CUB segmented sort + rank kernel. +// LARGE band cap (fused smem-sort kernel); beyond it -> HUGE (CUB segmented +// sort). constexpr int OVO_LARGE_MAX = 2500; -// Per-stream dense slab budget (float32 items). Sub-batching keeps -// (n_g × eff_sb_cols) ≤ this. 128M × 4B = 512 MB slab + same for sorted copy -// ≈ 1 GB / stream. Bigger = fewer launches; smaller = less per-stream memory. +// Per-stream dense slab budget (f32 items): 128M*4B=512MB slab + 512MB sorted +// copy ≈ 1GB/stream. Sub-batching keeps (n_g * eff_sb_cols) <= this. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; -// Query CUB device-segmented-radix-sort scratch size with a dummy launch. -// Every Wilcoxon sort uses float keys and (for SortPairs) int values/offsets. +// Host->device staging-ring slot cap (nnz). Bounds the page-locked footprint: +// a pack's device buffer is filled in row-blocks of <= this many nonzeros, so +// the cold pin stays small instead of seconds when pack nnz is large. 32M nnz +// (128MB vals + 128MB cols/slot) is the joint sweet spot across scales: it +// crushes the whole-pack pin at 2M (~2.7x) yet stays well clear of a sharp +// large-scale slowdown seen with much smaller blocks at multi-billion nnz. +constexpr size_t STAGE_RING_NNZ_CAP = 32 * 1024 * 1024; + +// Query CUB segmented-radix-sort scratch size. Float keys, int values/offsets. static inline size_t cub_segmented_sortkeys_temp_bytes(int num_items, int num_segments) { size_t bytes = 0; @@ -137,8 +138,8 @@ static inline size_t cub_segmented_sortpairs_temp_bytes(int num_items, return bytes; } -// Launch wrappers for the queries above. begin/end offset arrays may be -// contiguous (off, off + 1) or distinct (starts, ends). +// Launch wrappers. begin/end offset arrays may be contiguous (off, off+1) or +// distinct (starts, ends). static inline void cub_segmented_sortkeys( void* d_temp, size_t temp_bytes, const float* keys_in, float* keys_out, int num_items, int num_segments, const int* begin_offsets, @@ -167,14 +168,12 @@ static inline void cub_segmented_sortpairs( // device query fails. constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; -// CRITICAL device-limit query that powers every smem/gmem and tier decision. -// Returns the per-block shared-memory limit (cached per device). Consumed by -// ovr_smem_config, sparse_ovr_smem_config, cast_accumulate_smem_config, and -// make_ovo_tier_plan to decide when accumulators/sorts no longer fit in smem -// and must fall back to global memory or CUB. DO NOT hardcode a smem value in -// place of this call -- the gmem-fallback thresholds (e.g. sparse OVR ~3056 -// groups) auto-scale with the GPU because of it; falls back to 48 KB if the -// query fails. +// CRITICAL: per-block smem limit (cached per device) powering every smem/gmem +// and tier decision (ovr_smem_config, sparse_ovr_smem_config, +// cast_accumulate_smem_config, make_ovo_tier_plan). DO NOT hardcode a smem +// value in place of this call -- gmem-fallback thresholds (e.g. sparse OVR +// ~3056 groups) auto-scale with the GPU. Falls back to 48 KB if the query +// fails. static inline size_t wilcoxon_max_smem_per_block() { int device = 0; if (cudaGetDevice(&device) != cudaSuccess) { @@ -218,9 +217,8 @@ static inline int checked_int_product(size_t a, size_t b, const char* context) { } // Precompute per-batch CSC column offsets rebased to each batch's ptr_start, -// laid out [n_batches][sub_batch_cols+1], and upload once. Returns the device -// buffer (allocated from `pool`). Avoids a per-batch H2D from a transient host -// buffer in the CSC host streaming impls. +// laid out [n_batches][sub_batch_cols+1], upload once (from `pool`). Avoids a +// per-batch H2D from a transient host buffer. template static inline int* precompute_csc_batch_offsets(const IndptrT* h_indptr, int n_cols, int sub_batch_cols, @@ -245,16 +243,14 @@ static inline int* precompute_csc_batch_offsets(const IndptrT* h_indptr, return d_all_offsets; } -// Largest per-batch nonzero count we let a column batch reach. A batch is -// sorted in a single CUB segmented call (int32 item count) and addressed with -// int offsets, so it must stay below INT_MAX with margin. +// Max per-batch nnz: a batch is sorted in one CUB segmented call (int32 item +// count) and addressed with int offsets, so it must stay below INT_MAX. constexpr size_t SAFE_BATCH_NNZ = 2000000000; // < INT_MAX -// Shrink a column sub-batch (halving) until the densest contiguous window of -// `sub_batch_cols` columns holds <= cap nonzeros, keeping every batch's nnz -// within int32 for CUB and bounding the per-stream transpose/sort scratch. -// `col_nnz(i)` returns the nonzero count of column i. Worst case returns 1 -// (a single column, whose nnz is <= n_rows). +// Halve sub_batch_cols until the densest window holds <= cap nonzeros, keeping +// every batch's nnz within int32 for CUB and bounding per-stream transpose/sort +// scratch. col_nnz(i) = nnz of column i. Worst case returns 1 (single column, +// nnz <= n_rows). template static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, size_t cap, ColNnz col_nnz) { @@ -274,10 +270,8 @@ static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, return sub_batch_cols; } -// --------------------------------------------------------------------------- -// RAII guard for cudaHostRegister. Unregisters on scope exit even when an -// exception unwinds — prevents leaked host pinning on stream-sync failures. -// --------------------------------------------------------------------------- +// RAII guard for cudaHostRegister: unregisters on scope exit (incl. exception +// unwind), preventing leaked host pinning on stream-sync failures. struct HostRegisterGuard { void* ptr = nullptr; @@ -287,11 +281,10 @@ struct HostRegisterGuard { if (p && bytes > 0) { cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { - // Already-registered memory belongs to another owner; use it - // without unregistering here. Other failures mean mapped reads - // would be unsafe, so surface them immediately -- unless the - // caller opts into best-effort pinning (the pin is only a - // transfer speedup; plain H2D still works unpinned). + // Already-registered = owned elsewhere; use it without + // unregistering. Other failures make mapped reads unsafe, so + // surface them -- unless best_effort (pin is only a speedup; + // unpinned H2D still works). if (err == cudaErrorHostMemoryAlreadyRegistered || best_effort) { cudaGetLastError(); // clear sticky error flag @@ -325,12 +318,11 @@ struct HostRegisterGuard { } }; -// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind), -// fixing the leak when a throwing call skips a trailing manual destroy. The -// stream dtor SYNCHRONIZES before destroying. Convention: declare the -// RmmScratchPool BEFORE these guards so the streams (destroyed first) drain -// their in-flight kernels before the pool (destroyed last) frees the scratch -// those kernels read -- safe on the normal and exception-unwind paths alike. +// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind). +// Stream dtor SYNCHRONIZES before destroying. CRITICAL ordering: declare the +// RmmScratchPool BEFORE these guards so streams (destroyed first) drain +// in-flight kernels before the pool (destroyed last) frees the scratch they +// read. struct ScopedCudaStream { cudaStream_t stream = nullptr; @@ -431,18 +423,60 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] - * on-device. One thread per output slot. */ +// Per-stream pinned host staging (f32 vals + int32 cols) with a per-slot event, +// so a CPU gather into slot s overlaps GPU compute: wait(s) blocks only until +// slot s's prior H2D drained, not the whole pipeline. +struct HostStagingRing { + std::vector> vals; + std::vector> cols; + std::vector pin_v, pin_c; + std::vector evt; + std::vector used; + HostStagingRing(int n_streams, size_t nnz) + : vals(n_streams), + cols(n_streams), + pin_v(n_streams), + pin_c(n_streams), + evt(n_streams, nullptr), + used(n_streams, 0) { + size_t n = nnz ? nnz : 1; + for (int s = 0; s < n_streams; s++) { + vals[s].reset(new float[n]); + cols[s].reset(new int[n]); + pin_v[s] = HostRegisterGuard(vals[s].get(), n * sizeof(float)); + pin_c[s] = HostRegisterGuard(cols[s].get(), n * sizeof(int)); + cuda_check( + cudaEventCreateWithFlags(&evt[s], cudaEventDisableTiming), + "HostStagingRing event create"); + } + } + ~HostStagingRing() { + for (cudaEvent_t e : evt) + if (e) cudaEventDestroy(e); + } + void wait(int s) { + if (used[s]) + cuda_check(cudaEventSynchronize(evt[s]), "HostStagingRing reuse"); + } + void record(int s, cudaStream_t stream) { + cuda_check(cudaEventRecord(evt[s], stream), "HostStagingRing record"); + used[s] = true; + } + HostStagingRing(const HostStagingRing&) = delete; + HostStagingRing& operator=(const HostStagingRing&) = delete; +}; + +/** Fill linear segment offsets [0, stride, ..., n_segments*stride] on-device. + */ __global__ void fill_linear_offsets_kernel(int* __restrict__ out, int n_segments, int stride) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i <= n_segments) out[i] = i * stride; } -/** Fill per-row stats codes for a pack of K groups. - * Given pack_grp_offsets (size K+1, relative to pack start), write - * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). - * Binary search within the K+1 offsets. */ +/** Per-row stats codes for a pack of K groups. From pack_grp_offsets (size K+1, + * relative to pack start), write stats_codes[r] = base_slot + group_idx(r) via + * binary search over the K+1 offsets. */ __global__ void fill_pack_stats_codes_kernel( const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, int K, int base_slot) { @@ -460,11 +494,9 @@ __global__ void fill_pack_stats_codes_kernel( stats_codes[r] = base_slot + lo; } -/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. - * Grid-strided: supports arbitrary `count` (no single-block thread limit). - * Templated so that 64-bit global indptrs can produce 32-bit pack-local - * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). - */ +/** Rebase a slice of indptr: out[i] = indptr[col+i] - indptr[col]. Grid-strided + * (arbitrary `count`). Templated so 64-bit global indptrs produce 32-bit + * pack-local indptrs (per-pack nnz fits int32 via the memory budget). */ template __global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, IdxOut* __restrict__ out, int col, @@ -473,63 +505,72 @@ __global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); } -/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped - * pinned host memory. Block-per-row; threads in the block cooperate on the - * row's nnz. Each nnz is read from host over PCIe exactly once — no - * intermediate native-dtype GPU buffer, no second GPU pass. - * - * h_data / h_indices: device-accessible pointers into mapped pinned host - * memory (cudaHostRegisterMapped). - * d_indptr_full: full-matrix indptr on device. - * d_row_ids: rows to gather (size n_target_rows). - * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with - * out_indptr[i+1] - out_indptr[i] equal to the source row's - * nnz. - * - * Slot dispatch: - * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = - * fixed_slot (used for the Ref phase where every row maps to the same - * slot). slot ∉ [0, n_groups_stats) skips accumulation. - */ -template -__global__ void csr_gather_cast_accumulate_mapped_kernel( - const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, - const IndptrT* __restrict__ d_indptr_full, - const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, - const int* __restrict__ d_stats_codes, int fixed_slot, - float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, - double* __restrict__ group_sums, double* __restrict__ group_nnz, - int n_target_rows, int n_cols, int n_groups_stats, bool compute_sums, - bool compute_nnz) { +// Threaded host gather of selected rows into compact staging (f32 vals + int32 +// cols) at disjoint per-row offsets (compact_indptr - base) -> race-free. +// No-pin alternative to the mapped gather kernel: only the compacted slice +// crosses the bus. +template +static void host_gather_rows_compact(const InT* h_data, const IndexT* h_indices, + const IndptrT* h_indptr, + const int* row_ids, + const CompactT* compact_indptr, + CompactT base, int n_target, + float* stage_vals, int* stage_cols) { + host_parallel_ranges(n_target, [&](int i0, int i1) { + for (int i = i0; i < i1; i++) { + int r = row_ids[i]; + IndptrT rs = h_indptr[r]; + int nnz = (int)(h_indptr[r + 1] - rs); + size_t ds = (size_t)(compact_indptr[i] - base); + for (int k = 0; k < nnz; k++) { + stage_vals[ds + k] = (float)h_data[rs + k]; + stage_cols[ds + k] = (int)h_indices[rs + k]; + } + } + }); +} + +// Threaded host cast-copy of a contiguous nnz slice into staging (f32 + int32). +// CSC analogue of host_gather_rows_compact: contiguous column batch, no gather. +// nnz fits int32 (batch-bounded). +template +static void host_cast_copy_slice(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, float* stage_vals, + int* stage_cols) { + host_parallel_ranges(nnz, [&](int k0, int k1) { + for (int k = k0; k < k1; k++) { + stage_vals[k] = (float)h_data[start + k]; + stage_cols[k] = (int)h_indices[start + k]; + } + }); +} + +// Per-group stats over an already-compact CSR (accumulate half of the mapped +// gather kernel, decoupled for host-staged data). slot = stats_codes[r] or +// fixed_slot; slot outside [0,n_groups_stats) is skipped. +__global__ void csr_compact_accumulate_kernel( + const float* __restrict__ d_data_f32, const int* __restrict__ d_indices, + const int* __restrict__ d_indptr, const int* __restrict__ d_stats_codes, + int fixed_slot, double* __restrict__ group_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sums, bool compute_nnz) { int r = blockIdx.x; if (r >= n_target_rows) return; - int src_row = d_row_ids[r]; - IndptrT rs = d_indptr_full[src_row]; - IndptrT re = d_indptr_full[src_row + 1]; - int row_nnz = (int)(re - rs); - int ds = d_out_indptr[r]; int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; - bool accumulate = (slot >= 0 && slot < n_groups_stats); - for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { - InT v_in = h_data[rs + i]; - int c = (int)h_indices[rs + i]; - double v = (double)v_in; - d_out_data_f32[ds + i] = (float)v_in; - d_out_indices[ds + i] = c; - if (accumulate) { - if (compute_sums) { - atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); - } - if (compute_nnz && v != 0.0) { - atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); - } - } + if (slot < 0 || slot >= n_groups_stats) return; + int rs = d_indptr[r]; + int re = d_indptr[r + 1]; + for (int i = rs + threadIdx.x; i < re; i += blockDim.x) { + int c = d_indices[i]; + double v = (double)d_data_f32[i]; + if (compute_sums) atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + if (compute_nnz && v != 0.0) + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); } } -/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. - * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. - */ +/** Fill linear segment offsets [0, stride, ...] on the supplied stream (avoids + * serializing multi-stream pipelines). */ static inline void upload_linear_offsets(int* d_offsets, int n_segments, int stride, cudaStream_t stream) { int count = n_segments + 1; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index c0613f6b..37768d2b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -1,12 +1,9 @@ #pragma once /** - * CSR-direct OVO streaming pipeline. - * - * One C++ call does everything. Reference rows are extracted and sorted once - * across all columns, then each group sub-batch ranks against that cached - * reference slice. This mirrors the fast host-CSR path and avoids redoing the - * reference dense extraction + segmented sort for every column sub-batch. + * CSR-direct OVO streaming pipeline. Reference rows are extracted and sorted + * once across all columns; each group sub-batch ranks against that cached slice + * (mirrors the host-CSR path, avoids per-column reference re-extraction+sort). */ template static void ovo_streaming_csr_impl( @@ -16,9 +13,8 @@ static void ovo_streaming_csr_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - // Cap sub_batch_cols so the dense group slab (n_all_grp × sub_batch_cols, - // sorted in one CUB call) stays within int32. n_all_grp is a cell count, so - // it drives the cap; the reference side is chunked separately below. + // Cap sub_batch_cols so the group slab (n_all_grp × sub_batch_cols, one CUB + // sort) stays within int32; n_all_grp (cell count) drives the cap. { size_t cap = n_all_grp > 0 ? SAFE_BATCH_NNZ / (size_t)n_all_grp : (size_t)sub_batch_cols; @@ -53,8 +49,8 @@ static void ovo_streaming_csr_impl( } int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); { - // Reference cache holds 2 floats/col/ref-row; size it to ~a third of - // what the joint allocator can serve (leaving room for group buffers). + // Ref cache = 2 floats/col/ref-row; size to ~1/3 of the allocator + // budget, leaving room for group buffers. size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; size_t target_bytes = rmm_available_device_bytes(1.0 / 3.0); if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { @@ -82,9 +78,8 @@ static void ovo_streaming_csr_impl( } // Clamp streams to the per-stream scratch budget (mirrors host OVO): the - // group dense slab scales with the cell count, so a fixed stream count - // would OOM at scale. The reference cache is allocated separately, so - // reserve its footprint first. + // group slab scales with cell count, so a fixed stream count would OOM at + // scale. Ref cache is allocated separately, so reserve its footprint first. { size_t per_stream = sub_grp_items * sizeof(float) + @@ -128,8 +123,7 @@ static void ovo_streaming_csr_impl( bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].cub_temp = run_huge ? pool.alloc(cub_temp_bytes) : nullptr; - // LARGE/HUGE now share the ref tie base too: allocate whenever - // correcting. + // LARGE/HUGE share the ref tie base: allocate whenever correcting. bufs[s].ref_tie_sums = compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = @@ -243,10 +237,8 @@ static void ovo_streaming_csr_impl( } /** - * CSC-direct OVO streaming pipeline. - * - * Like the CSR variant, but extracts rows via lookup maps so it can operate on - * native CSC input without converting the whole matrix. + * CSC-direct OVO streaming pipeline. Like the CSR variant, but extracts rows + * via lookup maps to operate on native CSC input without converting the matrix. */ template static void ovo_streaming_csc_impl( @@ -256,9 +248,8 @@ static void ovo_streaming_csc_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - // Cap sub_batch_cols so both dense slabs (n_ref × sub_batch_cols and - // n_all_grp × sub_batch_cols, each sorted in one CUB call) stay within - // int32. These row counts are cell counts, so they drive the cap. + // Cap sub_batch_cols so both slabs (n_ref× and n_all_grp× sub_batch_cols, + // each one CUB sort) stay within int32; these cell counts drive the cap. { size_t max_rows = (size_t)std::max(n_ref, n_all_grp); size_t cap = @@ -306,8 +297,8 @@ static void ovo_streaming_csc_impl( } // Clamp streams to the per-stream scratch budget (mirrors host OVO): the - // ref/group dense slabs scale with cell counts, so a fixed stream count - // would OOM at scale. + // ref/group slabs scale with cell counts, so a fixed count would OOM at + // scale. { size_t per_stream = 2 * sub_ref_items * sizeof(float) + @@ -353,8 +344,7 @@ static void ovo_streaming_csc_impl( bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - // LARGE/HUGE now share the ref tie base too: allocate whenever - // correcting. + // LARGE/HUGE share the ref tie base: allocate whenever correcting. bufs[s].ref_tie_sums = compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].sub_rank_sums = diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 2b2208ef..1d5a5105 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -1,11 +1,9 @@ #pragma once /** - * Host-streaming CSC OVO pipeline. - * - * CSC arrays live on host. Only the sparse data for each sub-batch of - * columns is transferred to GPU. Row maps + group offsets are uploaded once. - * Results are written back to host per sub-batch. + * Host-streaming CSC OVO pipeline: CSC on host, only each column sub-batch is + * sent to GPU; row maps + group offsets uploaded once; results written back + * per sub-batch. */ template static void ovo_streaming_csc_host_impl( @@ -18,8 +16,7 @@ static void ovo_streaming_csc_host_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // Cap sub_batch_cols so neither the dense ref/group slabs (rows × - // sub_batch_cols, sorted in one CUB call) nor the per-column-batch nnz - // exceed int32. rows here are cell counts, so they dominate the dense cap. + // sub_batch_cols, one CUB call) nor per-batch nnz exceed int32. { size_t max_rows = (size_t)std::max(n_ref, n_all_grp); size_t dense_cap = @@ -73,9 +70,8 @@ static void ovo_streaming_csc_host_impl( if (nnz > max_nnz) max_nnz = nnz; } - // Reduce the stream count so the per-stream scratch fits the memory budget. - // The dense ref/group slabs scale with n_ref/n_all_grp (cell counts), so at - // scale a fixed N_STREAMS would exceed GPU memory and thrash/OOM. + // Clamp streams so per-stream scratch fits the budget: dense slabs scale + // with cell counts, so a fixed N_STREAMS would OOM at scale. { size_t per_stream = max_nnz * (sizeof(InT) + sizeof(float) + sizeof(IndexT)) + @@ -89,15 +85,10 @@ static void ovo_streaming_csc_host_impl( n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); } - // pool first: streams drain before it frees their scratch (see guard doc). + // pool first: streams drain before it frees their scratch (RAII order). RmmScratchPool pool; - // Pin host inputs before the streams so on an exception unwind the streams - // drain before the buffers are unregistered (mirrors the safe CSR order). - size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(IndexT)); + // No full-matrix page-lock: each column batch is cast-copied into small + // per-stream pinned staging (f32 vals + int32 cols) and bulk-H2D'd. ScopedCudaStreams streams(n_streams, cudaStreamDefault); int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; @@ -127,9 +118,8 @@ static void ovo_streaming_csc_host_impl( } struct StreamBuf { - InT* d_sparse_data_orig; float* d_sparse_data_f32; - IndexT* d_sparse_indices; + int* d_sparse_indices; int* d_indptr; float* ref_dense; float* ref_sorted; @@ -147,17 +137,15 @@ static void ovo_streaming_csc_host_impl( }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); bufs[s].ref_dense = pool.alloc(sub_ref_items); bufs[s].ref_sorted = pool.alloc(sub_ref_items); bufs[s].grp_dense = pool.alloc(sub_grp_items); bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - // LARGE/HUGE now share the ref tie base too: allocate whenever - // correcting. + // LARGE/HUGE share the ref tie base: allocate whenever correcting. bufs[s].ref_tie_sums = compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; bufs[s].d_rank_sums = @@ -182,6 +170,9 @@ static void ovo_streaming_csc_host_impl( } } + // Per-stream pinned staging for the contiguous column-batch cast-copy. + HostStagingRing stage(n_streams, max_nnz); + int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); bool cast_use_gmem = false; @@ -202,22 +193,29 @@ static void ovo_streaming_csc_host_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // H2D: sparse data for this column range (native dtype) IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); - checked_int_span(nnz, "OVO host CSC active batch nnz"); - cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, - nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + int nnz_i = checked_int_span(nnz, "OVO host CSC active batch nnz"); + + // Cast-copy column batch into pinned staging, bulk H2D; the event lets + // the next copy overlap compute. + stage.wait(s); + host_cast_copy_slice(h_data, h_indices, (size_t)ptr_start, nnz_i, + stage.vals[s].get(), stage.cols[s].get()); + cudaMemcpyAsync(buf.d_sparse_data_f32, stage.vals[s].get(), + nnz * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, stage.cols[s].get(), + nnz * sizeof(int), cudaMemcpyHostToDevice, stream); + stage.record(s, stream); int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); - // Cast to float32 for sort + accumulate stats in float64 - launch_ovr_cast_and_accumulate_sparse( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + // Data already f32 on device: accumulate stats (cast is f32->f32 + // no-op). + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_f32, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_nnz, sb_cols, n_groups_stats, compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); @@ -275,26 +273,11 @@ static void ovo_streaming_csc_host_impl( } /** - * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. - * - * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full - * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack - * gathers only its rows over PCIe via a UVA kernel — the full matrix is never - * transferred to GPU. - * - * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort - * to d_ref_sorted (cached for the whole run). - * Phase 2 (per pack, round-robin across N_STREAMS): - * 1. rebase per-pack output indptr from the pre-uploaded global compacted - * indptr. - * 2. rebase per-pack group offsets + build per-row stats codes. - * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes - * compacted f32 data + indices and accumulates per-group stats. - * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * Host CSR OVO pipeline (no full-matrix page-lock). * - * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized - * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). - * Full CSR stays on host (pinned-mapped). + * Ref + each group pack are host-gathered into pinned staging, bulk-H2D'd, + * then extract dense -> segmented sort -> rank vs cached sorted ref -> scatter. + * Packs round-robin across N_STREAMS; per-slot events overlap gather + compute. */ template static void ovo_streaming_csr_host_impl( @@ -306,11 +289,8 @@ static void ovo_streaming_csr_host_impl( bool compute_nnz, bool compute_sums, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; - // Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)). - // Use IndptrT for the global compacted indptr because the grp side can - // exceed 2^31 nnz on very large / dense matrices. Ref always fits in - // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the - // downstream CUB segmented-sort temp sizing. + // Compacted indptrs on host. IndptrT for grp (can exceed 2^31 nnz when + // large/dense); ref stays int32 (n_ref × n_cols ≪ 2B, matches CUB temp). std::vector h_ref_indptr_compact(n_ref + 1); h_ref_indptr_compact[0] = 0; for (int i = 0; i < n_ref; i++) { @@ -330,7 +310,7 @@ static void ovo_streaming_csr_host_impl( } int ref_nnz = h_ref_indptr_compact[n_ref]; - // grp: compacted indptr over concatenated test-group rows (IndptrT). + // grp: compacted indptr over concatenated test-group rows. std::vector h_grp_indptr_compact(n_all_grp + 1); h_grp_indptr_compact[0] = 0; for (int i = 0; i < n_all_grp; i++) { @@ -361,9 +341,8 @@ static void ovo_streaming_csr_host_impl( GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; if ((size_t)target_rows > budget_cap_rows) target_rows = (int)budget_cap_rows; - // Also bound each pack's compacted nnz: it feeds int32 CUB item counts - // and int offsets, so a dense pack must stay under INT_MAX. This splits - // dense perturbation groups across more packs. + // Bound each pack's compacted nnz < INT_MAX (feeds int32 CUB item + // counts + offsets); splits dense groups across more packs. constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe int cur_first = 0; @@ -426,34 +405,20 @@ static void ovo_streaming_csr_host_impl( (size_t)n_groups_stats * n_cols * sizeof(double)); } - // Pin full host data + indices as MAPPED (zero-copy accessible) - size_t full_nnz = (size_t)h_indptr[n_full_rows]; - HostRegisterGuard _pin_data(const_cast(h_data), - full_nnz * sizeof(InT), cudaHostRegisterMapped); - HostRegisterGuard _pin_indices(const_cast(h_indices), - full_nnz * sizeof(IndexT), - cudaHostRegisterMapped); - - // Get device-accessible pointers (UVA makes these equal to host ptrs on - // Linux x86-64, but the API is the safe/portable way). - InT* d_data_zc = nullptr; - IndexT* d_indices_zc = nullptr; - if (full_nnz > 0) { - cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, - const_cast(h_data), 0); - cudaError_t e2 = cudaHostGetDevicePointer( - (void**)&d_indices_zc, const_cast(h_indices), 0); - if (e1 != cudaSuccess || e2 != cudaSuccess) { - throw std::runtime_error( - std::string("cudaHostGetDevicePointer failed: ") + - cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); - } - } - - // Upload full indptr (keep native IndptrT — can exceed int32) - IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); - cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), - cudaMemcpyHostToDevice); + // No full-matrix page-lock (the 280GB cudaHostRegister was ~7s/call). The + // gather reads pageable CSR and transfers only the compacted slice; just + // the small pinned staging buffers are registered. + (void)n_full_rows; + + // Pinned staging for the reference gather (compacted f32 vals + int32 + // cols). Uninitialized: the gather overwrites it, so skip a multi-GB zero. + size_t ref_stage_n = ref_nnz ? (size_t)ref_nnz : 1; + std::unique_ptr h_ref_stage_vals(new float[ref_stage_n]); + std::unique_ptr h_ref_stage_cols(new int[ref_stage_n]); + HostRegisterGuard pin_ref_vals(h_ref_stage_vals.get(), + ref_stage_n * sizeof(float)); + HostRegisterGuard pin_ref_cols(h_ref_stage_cols.get(), + ref_stage_n * sizeof(int)); // Upload row_ids + compacted indptrs + group boundaries int* d_ref_row_ids = pool.alloc(n_ref); @@ -470,11 +435,9 @@ static void ovo_streaming_csr_host_impl( cudaMemcpyHostToDevice); // Phase 1: Ref setup (scoped scratch, ref_sorted persists). - // The full-width sorted reference cache d_ref_sorted is [n_ref × n_cols], - // but it is built one COLUMN CHUNK at a time so each CUB segmented sort - // stays within int32 (n_ref × ref_chunk_cols items) and the dense extract - // scratch is bounded to a chunk instead of the whole [n_ref × n_cols] slab. - // This is what lets large references (n_ref × n_cols > INT_MAX) work. + // The [n_ref × n_cols] sorted cache is built one COLUMN CHUNK at a time so + // each CUB segmented sort stays within int32 and the extract scratch is + // chunk-bounded; this is what lets n_ref × n_cols > INT_MAX work. size_t ref_items = (size_t)n_ref * (size_t)n_cols; if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { throw std::runtime_error( @@ -513,24 +476,36 @@ static void ovo_streaming_csr_host_impl( cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); - // Fused gather + cast + stats for ref (fixed slot = n_test): one PCIe - // pass, no intermediate native-dtype buffer, all-column stats once. + // Host-gather ref rows into pinned staging, bulk H2D, accumulate stats. if (n_ref > 0 && ref_nnz > 0) { - csr_gather_cast_accumulate_mapped_kernel - <<>>( - d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, - d_ref_indptr, /*d_stats_codes=*/nullptr, - /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + host_gather_rows_compact(h_data, h_indices, h_indptr, h_ref_row_ids, + h_ref_indptr_compact.data(), 0, n_ref, + h_ref_stage_vals.get(), + h_ref_stage_cols.get()); + cuda_check(cudaMemcpyAsync(d_ref_data_f32, h_ref_stage_vals.get(), + (size_t)ref_nnz * sizeof(float), + cudaMemcpyHostToDevice, ref_stream), + "OVO host CSR ref staged vals H2D"); + cuda_check(cudaMemcpyAsync(d_ref_indices, h_ref_stage_cols.get(), + (size_t)ref_nnz * sizeof(int), + cudaMemcpyHostToDevice, ref_stream), + "OVO host CSR ref staged cols H2D"); + if (compute_sums || compute_nnz) { + csr_compact_accumulate_kernel<<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, + /*d_stats_codes=*/nullptr, /*fixed_slot=*/n_test, d_group_sums, d_group_nnz, n_ref, n_cols, n_groups_stats, compute_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + CUDA_CHECK_LAST_ERROR(csr_compact_accumulate_kernel); + } } size_t ref_cub_bytes = cub_segmented_sortkeys_temp_bytes( ref_chunk_items_i32, ref_chunk_cols); ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); - // Extract + segment-sort the reference one column chunk at a time. + // Extract + segment-sort the reference per column chunk. for (int cs = 0; cs < n_cols; cs += ref_chunk_cols) { int ce = std::min(cs + ref_chunk_cols, n_cols); int cc = ce - cs; @@ -575,6 +550,36 @@ static void ovo_streaming_csr_host_impl( cub_grp_bytes = cub_segmented_sortkeys_temp_bytes(max_sub_items_i32, max_segments); } + int max_pack_kernel_seg = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR pack segment buffer"); + + // Clamp streams to the device-memory budget (90% of free). The per-stream + // pack buffers + dense slabs dominate device use, so a fixed stream count + // OOMs at scale / on smaller GPUs. The sorted ref cache + small shared + // arrays are already allocated, so the measured free already excludes them; + // budget is what is left for the per-stream scratch. Fewer streams just + // means less gather/compute overlap, not a re-stream. + { + size_t per_stream = + 2 * max_pack_nnz * sizeof(float) // grp data + idx + + (size_t)(max_pack_rows + 1) * sizeof(int) // grp indptr + + (size_t)max_pack_rows * sizeof(int) // stats codes + + (size_t)(max_pack_K + 1) * sizeof(int) // pack grp offsets + + max_sub_items * sizeof(float) // grp dense + + 2 * (size_t)max_pack_K * max_pack_sb_cols * + sizeof(double) // rank+tie + + (size_t)max_pack_sb_cols * sizeof(double) // ref tie + + + (may_need_cub + ? max_sub_items * sizeof(float) // grp sorted + + (size_t)max_pack_K * sizeof(int) // sort ids + + 2 * (size_t)max_pack_kernel_seg * sizeof(int) // segs + + cub_grp_bytes // cub temp + : 0); + size_t budget = rmm_available_device_bytes(0.9); + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } ScopedCudaStreams streams(n_streams, cudaStreamDefault); @@ -595,9 +600,6 @@ static void ovo_streaming_csr_host_impl( double* d_tie_corr; }; std::vector bufs(n_streams); - int max_pack_kernel_seg = - checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, - "OVO host CSR pack segment buffer"); for (int s = 0; s < n_streams; s++) { bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); @@ -625,6 +627,16 @@ static void ovo_streaming_csr_host_impl( } } + // Small rolling pinned staging shared across packs: each pack's device + // buffer is filled in row-blocks of <= stage_cap nnz, so the page-locked + // footprint stays small regardless of pack nnz (the whole-pack pin was the + // dominant cost at small/medium scale). Extra slots let the host gather run + // ahead of the in-flight H2Ds. + size_t stage_cap = std::min(max_pack_nnz, STAGE_RING_NNZ_CAP); + int ring_slots = n_streams + 2; + HostStagingRing stage(ring_slots, stage_cap); + int stage_slot = 0; + for (int p = 0; p < (int)packs.size(); p++) { const Pack& pack = packs[p]; int K = pack.end - pack.first; @@ -656,8 +668,8 @@ static void ovo_streaming_csr_host_impl( int pack_rows = pack.n_rows; int pack_sb = pack.sb_cols; - // Rebase pack's output indptr from pre-uploaded global compacted indptr - // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + // Rebase pack's output indptr (IndptrT → int32: pack nnz is bounded by + // GROUP_DENSE_BUDGET so fits). { int count = pack_rows + 1; int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -667,8 +679,7 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Build per-pack group offsets on GPU — needed for stats codes before - // the fused gather kernel can run. + // Per-pack group offsets on GPU — needed for stats codes. { int count = K + 1; int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -684,18 +695,54 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); } - // Fused gather + cast + stats for the pack: one PCIe pass (reads mapped - // host via UVA), no intermediate native-dtype buffer. + // Host-gather pack rows into the rolling staging in row-blocks (<= + // stage_cap nnz each), H2D each block into the pack's device buffer at + // its nnz offset, then accumulate stats over the full pack. if (pack.nnz > 0) { - csr_gather_cast_accumulate_mapped_kernel - <<>>( - d_data_zc, d_indices_zc, d_indptr_full, - d_grp_row_ids + row_start, buf.d_grp_indptr, - buf.d_pack_stats_codes, /*fixed_slot=*/-1, - buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + IndptrT pack_base = h_grp_indptr_compact[row_start]; + int rb0 = 0; + while (rb0 < pack_rows) { + IndptrT blk_base = h_grp_indptr_compact[row_start + rb0]; + int rb1 = rb0 + 1; + while (rb1 < pack_rows && + (size_t)(h_grp_indptr_compact[row_start + rb1 + 1] - + blk_base) <= stage_cap) + rb1++; + size_t blk_nnz = + (size_t)(h_grp_indptr_compact[row_start + rb1] - blk_base); + size_t dev_off = (size_t)(blk_base - pack_base); + int slot = stage_slot % ring_slots; + stage_slot++; + // wait drains a prior H2D out of this slot before we overwrite + // it; the event lets the next gather overlap the in-flight H2D. + stage.wait(slot); + host_gather_rows_compact( + h_data, h_indices, h_indptr, + h_grp_row_ids + row_start + rb0, + h_grp_indptr_compact.data() + row_start + rb0, blk_base, + rb1 - rb0, stage.vals[slot].get(), stage.cols[slot].get()); + cuda_check(cudaMemcpyAsync(buf.d_grp_data_f32 + dev_off, + stage.vals[slot].get(), + blk_nnz * sizeof(float), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged vals H2D"); + cuda_check(cudaMemcpyAsync(buf.d_grp_indices + dev_off, + stage.cols[slot].get(), + blk_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged cols H2D"); + stage.record(slot, stream); + rb0 = rb1; + } + if (compute_sums || compute_nnz) { + csr_compact_accumulate_kernel<<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, d_group_sums, d_group_nnz, pack_rows, n_cols, n_groups_stats, compute_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + CUDA_CHECK_LAST_ERROR(csr_compact_accumulate_kernel); + } } int col = 0; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index 6f9346e9..f9495b65 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -6,8 +6,7 @@ /** * Build CUB segmented-sort ranges for HUGE-band groups. Ranges point into the - * original dense group layout so the presorted rank kernel reads normal - * per-group positions. + * original dense group layout (normal per-group positions). */ __global__ void build_huge_seg_offsets_kernel( const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, @@ -26,10 +25,9 @@ __global__ void build_huge_seg_offsets_kernel( } /** - * Sizing knobs for LARGE-band dispatch: when the largest group fits in shared - * memory, a fused bitonic-sort + binary-search kernel handles the group per - * block; otherwise fall back to the HUGE band (CUB segmented sort + pre-sorted - * rank kernel). + * Sizing knobs for LARGE-band dispatch: largest group fits in smem -> fused + * bitonic-sort + binary-search kernel per block; else fall back to HUGE band + * (CUB segmented sort + pre-sorted rank kernel). */ struct OvoTierPlan { int max_grp_size = 0; @@ -41,18 +39,15 @@ struct OvoTierPlan { size_t large_smem = 0; }; -// Single source of truth for OVO tier dispatch (used by the dense path AND all -// four sparse OVO impls, which extract ref+group rows to dense then call this). -// Scans group sizes once; returns which size bands to co-launch (by max group): +// Single source of truth for OVO tier dispatch (dense + all four sparse OVO +// impls). Scans group sizes once; co-launches by max group size: // MEDIUM (<=512): ovo_rank_medium_kernel (no sort; O(n^2) in-group count) -// LARGE (<=2500): ovo_rank_large_kernel (fused smem bitonic sort) -// HUGE (>2500): CUB segmented sort + ovo_rank_huge_kernel (presorted rank) -// MEDIUM is the smallest tier (the WARP/SMALL sub-tiers were removed -- no -// measurable speedup on real data; archived in -// .claude/wilcoxon-warp-small-tiers-removed.md). MEDIUM co-launches with LARGE -// or HUGE; the upper tier skips groups ≤ OVO_MEDIUM_MAX (skip_n_grp_le). LARGE -// is device-adapted: if its smem would exceed the per-block limit it falls back -// to HUGE. +// LARGE (<=2500): ovo_rank_sorted_kernel (fused smem bitonic sort) +// HUGE (>2500): CUB segmented sort + ovo_rank_sorted_kernel +// MEDIUM co-launches with the upper tier, which skips groups <= OVO_MEDIUM_MAX +// (skip_n_grp_le). LARGE falls back to HUGE if smem exceeds the per-block +// limit. (WARP/SMALL sub-tiers removed -- +// .claude/wilcoxon-warp-small-tiers-removed.md.) static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { OvoTierPlan c; for (int g = 0; g < n_groups; g++) { @@ -62,18 +57,17 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { if (sz > OVO_MEDIUM_MAX) c.above_medium = true; } - // run_large: the fused smem-sort fast path for groups > MEDIUM but ≤ LARGE. + // run_large: fused smem-sort fast path for groups > MEDIUM but <= LARGE. c.run_large = c.above_medium && (c.max_grp_size <= OVO_LARGE_MAX); if (c.run_large) { c.large_padded = 1; while (c.large_padded < c.max_grp_size) c.large_padded <<= 1; c.large_tpb = std::min(c.large_padded, MAX_THREADS_PER_BLOCK); - c.large_smem = (size_t)c.large_padded * sizeof(float) + - WARP_REDUCE_BUF * sizeof(double); - // Device-adapt: if the fused-sort buffer exceeds the per-block smem - // limit, fall back to HUGE (no smem cap) instead of launching a kernel - // that would fail. Inert at the current ~16.6KB threshold; guards - // against threshold/device-limit changes. + // dynamic smem = grp_smem only; warp_buf is static in the kernel. + c.large_smem = (size_t)c.large_padded * sizeof(float); + // Device-adapt: if fused-sort buffer exceeds the per-block smem limit, + // fall back to HUGE (no smem cap). Inert at the ~16.6KB threshold; + // guards against threshold/device-limit changes. if (c.large_smem > wilcoxon_max_smem_per_block()) { c.run_large = false; } @@ -117,9 +111,8 @@ static inline void launch_ovo_medium( CUDA_CHECK_LAST_ERROR(ovo_rank_medium_kernel); } -// Per-stream scratch consumed by ovo_dispatch_tiers (one set per CUDA stream). -// grp_sorted/grp_seg_*/grp_cub_temp are only needed for the HUGE band and may -// be null otherwise. +// Per-stream scratch for ovo_dispatch_tiers (one set per CUDA stream). +// grp_sorted/grp_seg_*/grp_cub_temp are HUGE-band only; may be null otherwise. struct OvoTierScratch { double* ref_tie_sums; // [sb_cols] pre-computed reference tie sums, or null double* sub_rank_sums; // [n_groups * sb_cols] rank-sum output accumulator @@ -130,12 +123,11 @@ struct OvoTierScratch { uint8_t* grp_cub_temp; // HUGE: CUB scratch }; -// SINGLE OVO ranking engine, shared by the dense path and all four sparse OVO -// impls (host/device CSC/CSR). Given a sorted reference slice and a dense group -// slice for one column sub-batch, runs the size-banded dispatch from `plan` -// (see make_ovo_tier_plan): co-launch MEDIUM for groups ≤512, then LARGE (fused -// smem sort) OR HUGE (CUB segmented sort) for the rest. Callers differ only in -// how they produce ref_sorted / grp_dense. +// SINGLE OVO ranking engine, shared by dense + all four sparse OVO impls +// (host/device CSC/CSR). Given a sorted reference slice and a dense group slice +// for one column sub-batch, runs the size-banded dispatch from `plan` (see +// make_ovo_tier_plan). Callers differ only in how they produce ref_sorted / +// grp_dense. static inline void ovo_dispatch_tiers( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const OvoTierPlan& plan, const OvoTierScratch& sc, @@ -143,9 +135,8 @@ static inline void ovo_dispatch_tiers( int sb_grp_items_actual, int tpb_rank, int n_ref, int n_all_grp, int sb_cols, int n_groups, bool compute_tie_corr, cudaStream_t stream) { // No-tie fast path (tie_correct=False, the default): rank each group value - // vs the sorted reference only (U-identity), skipping the group sort and - // all tiers. grp_dense is unsorted here, which is exactly what this kernel - // wants. + // vs the sorted reference only (U-identity), skipping group sort + all + // tiers. grp_dense is unsorted here, which this kernel wants. if (!compute_tie_corr) { constexpr int VS_REF_BLOCK = 256; dim3 grid(sb_cols, n_groups); @@ -158,14 +149,13 @@ static inline void ovo_dispatch_tiers( bool run_large = plan.above_medium && plan.run_large; bool run_huge = plan.above_medium && !run_large; - // All tiers (MEDIUM/LARGE/HUGE) share the precomputed reference tie base, - // so compute it once per column whenever correcting. + // All tiers share the precomputed reference tie base; compute once/column. if (compute_tie_corr) { launch_ref_tie_sums(ref_sorted, sc.ref_tie_sums, n_ref, sb_cols, stream); } - // MEDIUM is the smallest tier: it handles every group ≤ OVO_MEDIUM_MAX - // (skip_n_grp_le = 0). LARGE/HUGE then take the groups above MEDIUM. + // MEDIUM handles every group <= OVO_MEDIUM_MAX (skip_n_grp_le = 0); + // LARGE/HUGE take the groups above MEDIUM. if (plan.run_medium) { launch_ovo_medium(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, @@ -176,12 +166,12 @@ static inline void ovo_dispatch_tiers( int upper_skip_le = plan.above_medium ? OVO_MEDIUM_MAX : 0; if (plan.above_medium && run_large) { dim3 grid(sb_cols, n_groups); - ovo_rank_large_kernel<<>>( - ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, - sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, plan.large_padded, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_rank_large_kernel); + ovo_rank_sorted_kernel + <<>>( + ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, plan.large_padded, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_sorted_kernel); } else if (run_huge) { int sb_grp_seg = checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, @@ -198,10 +188,10 @@ static inline void ovo_dispatch_tiers( "OVO huge-tier group segmented sort"); dim3 grid(sb_cols, n_groups); - ovo_rank_huge_kernel<<>>( + ovo_rank_sorted_kernel<<>>( ref_sorted, sc.grp_sorted, grp_offsets, sc.ref_tie_sums, sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr, upper_skip_le); - CUDA_CHECK_LAST_ERROR(ovo_rank_huge_kernel); + n_groups, compute_tie_corr, /*large_padded=*/0, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_sorted_kernel); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 0973757e..3b77d9be 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -19,13 +19,11 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, } // CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). -// -// Decide smem-vs-gmem for the DENSE OVR rank kernel. Per-block accumulator is -// (n_groups + 32) doubles; when that exceeds the per-block smem limit (~48 KB) -// it must fall back to a global-memory accumulator (use_gmem=true), flipping at -// roughly n_groups > 6112. Not dead: smem mode with an oversized request fails -// to launch. Limit is device-queried via wilcoxon_max_smem_per_block(), so it -// auto-scales. +// smem-vs-gmem for the DENSE OVR rank kernel. Per-block accumulator is +// (n_groups+32) doubles; over the per-block smem limit (~48 KB) it falls back +// to gmem (use_gmem=true), flipping at roughly n_groups > 6112. Not dead: smem +// mode with an oversized request fails to launch. Limit device-queried via +// wilcoxon_max_smem_per_block(), so it auto-scales. static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { @@ -38,17 +36,14 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { } /** - * CRITICAL — DO NOT REMOVE the gmem branch. This is the load-bearing path for - * Perturb-seq / pooled-CRISPR DE, where n_groups is in the thousands. - * - * Decide smem-vs-gmem for the sparse OVR rank kernel. Per-block accumulator is - * (2*n_groups + 32) doubles (grp_sums + grp_nz_count + warp buf); when that - * exceeds the per-block smem limit (~48 KB) the kernel CANNOT launch in smem - * mode, so use_gmem=true routes accumulation to a caller-provided gmem buffer. - * Flips at roughly n_groups > 3056. Reviewers/static analysis have twice - * mistaken this fallback for dead code; it is the ONLY path that works at large - * n_groups. Limit is device-queried via wilcoxon_max_smem_per_block(), so the - * threshold auto-scales. + * CRITICAL — DO NOT REMOVE the gmem branch. Load-bearing path for Perturb-seq / + * pooled-CRISPR DE (n_groups in the thousands). smem-vs-gmem for the sparse OVR + * rank kernel. Per-block accumulator is (2*n_groups+32) doubles (grp_sums + + * grp_nz_count + warp buf); over the per-block smem limit (~48 KB) the kernel + * CANNOT launch in smem mode, so use_gmem=true routes to a caller gmem buffer. + * Flips at roughly n_groups > 3056. Twice mistaken for dead code; it is the + * ONLY path that works at large n_groups. Limit device-queried via + * wilcoxon_max_smem_per_block(), so the threshold auto-scales. */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); @@ -75,14 +70,11 @@ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, } /** - * Read one transferred dense column-batch (native dtype `T`) into float32 in - * F-order (column-major), the layout the segmented sort expects. Operates on a - * single sub-batch (n_rows x sb_cols) only -- the full array is never - * reordered/transposed. - * f_order=true : staging is already F-order -> identity cast. - * f_order=false: staging is C-order (n_rows x sb_cols, row-major); each - * element is read into its F-order slot while casting. - * Grid-stride over n_rows*sb_cols elements. + * Read one dense column-batch (native `T`) into f32 F-order (the layout the + * segmented sort expects); single sub-batch only, full array never transposed. + * f_order=true : staging already F-order -> identity cast. + * f_order=false: staging C-order; read into the F-order slot while casting. + * Grid-stride over n_rows*sb_cols. */ template __global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, @@ -103,12 +95,11 @@ __global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, } /** - * Accumulate per-(group, column) sums (+ optional nnz) from a transferred dense - * column-batch, reading the NATIVE dtype staging in f64 so means match the - * Aggregate path (the f32 cast is only for ranking). One block per column. - * `group_sums`/`group_nnz` are this batch's (n_groups x sb_cols) buffers and - * must be pre-zeroed. Mirrors the sparse cast+accumulate the CSC host path - * runs. + * Accumulate per-(group, column) sums (+optional nnz) from a dense + * column-batch, reading NATIVE staging in f64 so means match the Aggregate path + * (the f32 cast is only for ranking). One block per column; + * group_sums/group_nnz are this batch's (n_groups x sb_cols) buffers and must + * be pre-zeroed. */ template __global__ void dense_group_accumulate_kernel( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 774c4953..16788103 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -2,10 +2,8 @@ /** * Sparse-aware host-streaming CSC OVR pipeline. - * - * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column - * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead - * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. + * Sorts only stored nonzeros per column: GPU mem O(max_batch_nnz), sort work + * O(nnz) not O(n_rows). */ template static void ovr_sparse_csc_host_streaming_impl( @@ -16,9 +14,8 @@ static void ovr_sparse_csc_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - // Bound each column batch's nnz so CUB item counts stay within int32 and - // the per-stream sort buffers fit the memory budget (column counts come - // free from the CSC indptr). + // Bound each batch's nnz: CUB item counts stay within int32 + per-stream + // sort buffers fit the budget (column counts free from CSC indptr). { constexpr size_t BYTES_PER_NNZ = sizeof(InT) + 2 * sizeof(float) + 2 * sizeof(IndexT) + 8; @@ -50,8 +47,8 @@ static void ovr_sparse_csc_host_streaming_impl( // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; - // Pin host inputs before the streams so on an exception unwind the streams - // drain before the buffers are unregistered (mirrors the safe CSR order). + // Pin host inputs before streams: on exception unwind streams drain before + // buffers are unregistered (mirrors safe CSR order). size_t total_nnz = (size_t)h_indptr[n_cols]; HostRegisterGuard _pin_data(const_cast(h_data), total_nnz * sizeof(InT)); @@ -101,8 +98,7 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), cudaMemcpyHostToDevice); - // Pre-compute rebased per-batch offsets and upload once (avoids per-batch - // H2D copy from a transient host buffer). + // Pre-compute rebased per-batch offsets, upload once (no per-batch H2D). int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; int* d_all_offsets = precompute_csc_batch_offsets( h_indptr, n_cols, sub_batch_cols, n_batches, pool, @@ -115,8 +111,8 @@ static void ovr_sparse_csc_host_streaming_impl( size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); - // In gmem mode the sparse rank kernel accumulates into rank_sums directly - // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + // gmem mode: rank kernel accumulates into rank_sums directly, needs a + // per-stream nz_count scratch buffer sized (n_groups, sb_cols). for (int s = 0; s < n_streams; s++) { if (rank_use_gmem) { bufs[s].d_nz_scratch = @@ -141,7 +137,7 @@ static void ovr_sparse_csc_host_streaming_impl( int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), "OVR host CSC active batch nnz"); - // H2D: this column range's sparse data (native dtype) + // H2D: this column range's sparse data (native dtype). if (batch_nnz > 0) { cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, (size_t)batch_nnz * sizeof(InT), @@ -220,16 +216,12 @@ static void ovr_sparse_csc_host_streaming_impl( /** * Out-of-core OVR for a host CSR too large to stage on the GPU. * - * Column indices are sorted within each row, so a per-row cursor (init 0) lets - * us walk the matrix ONCE: for each ascending column batch [col, col_end) every - * row resumes where the previous batch stopped and emits its run of columns in - * range. The cursor advances monotonically, so each nonzero is read on the host - * and bulk-transferred exactly once over the matrix's lifetime -- the gathered, - * already-compacted per-batch slice is the only thing that crosses the bus (a - * true 1x transfer, versus re-streaming the whole CSR per batch). The column - * histogram is counted on the host; the full CSR is never page-locked (the - * gather reads it on the CPU). Single stream: the per-batch CSC accumulator - * plus the gather buffers fill the device. + * PRECONDITION: column indices sorted within each row. A per-row cursor (init + * 0) walks the matrix ONCE: for each ascending column batch [col, col_end) + * every row resumes where the prior batch stopped. Cursor advances + * monotonically, so each nonzero is read + bulk-transferred exactly once (true + * 1x transfer, not per-batch whole-CSR re-streaming). Histogram counted on + * host; full CSR never page-locked (gather reads it on CPU). Single stream. */ template static void ovr_sparse_csr_host_rowstream_impl( @@ -245,8 +237,8 @@ static void ovr_sparse_csr_host_rowstream_impl( int tpb = UTIL_BLOCK_SIZE; size_t budget = rmm_available_device_bytes(0.8); - // ---- Phase 0: column histogram on the host, threaded by row range. Each - // worker counts into a private array (no false sharing), merged after. ---- + // ---- Phase 0: host column histogram, threaded by row range; each worker + // counts into a private array (no false sharing), merged after. ---- std::vector h_col_counts(n_cols, 0); { int n_workers = host_worker_count(); @@ -261,9 +253,9 @@ static void ovr_sparse_csr_host_rowstream_impl( for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; } - // ---- Column batch size: int32 CUB limit + device buffers that fit the - // budget. Per-nnz device footprint: gathered mini-CSR (val + col) + CSC - // accumulator (val + f32 + row) + sort outputs (key + row) + CUB temp. ---- + // ---- Column batch size: int32 CUB limit + device buffers fit budget. + // Per-nnz: gather mini-CSR (val+col) + CSC accum (val+f32+row) + sort out + // (key+row) + CUB temp. ---- constexpr size_t BYTES_PER_NNZ = 2 * sizeof(InT) // gather val + csc val + 2 * sizeof(float) // f32 key in + out + 3 * sizeof(int) // gather col + 2 rows @@ -303,9 +295,9 @@ static void ovr_sparse_csr_host_rowstream_impl( size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); - // ---- Host gather staging (pinned for fast bulk H2D) + per-row cursor. - // Full CSR is NOT page-locked: gather reads it on the CPU, only the - // compacted per-batch slice crosses the bus. ---- + // ---- Host gather staging (pinned for bulk H2D) + per-row cursor. Full CSR + // NOT page-locked: gather reads it on CPU, only compacted slice crosses + // bus. size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; std::vector h_gather_vals(stage_nnz); std::vector h_gather_cols(stage_nnz); @@ -344,19 +336,19 @@ static void ovr_sparse_csr_host_rowstream_impl( rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; - // ---- One linear pass, column-batched. The per-row cursor advances - // monotonically (indices sorted + batches ascending), so each nonzero is - // read and transferred exactly once -- no whole-matrix re-streaming. Gather - // is threaded: count each row's run, prefix-sum to per-row output offsets, - // copy rows in parallel into disjoint staging ranges. ---- + // ---- One linear column-batched pass. Cursor advances monotonically + // (sorted indices + ascending batches): each nonzero read/transferred once, + // no whole-matrix re-streaming. Threaded gather: count each row's run, + // prefix-sum to per-row offsets, copy rows into disjoint staging ranges. + // ---- std::vector g_count(n_rows); int col = 0; for (int b = 0; b < n_batches; b++) { int sb_cols = std::min(sub_batch_cols, n_cols - col); int col_end = col + sb_cols; - // Count this batch's run per row: sorted indices -> binary search from - // the cursor for the first column >= col_end. + // Per-row run for this batch: binary-search sorted indices from cursor + // for first column >= col_end. host_parallel_ranges(n_rows, [&](int r0, int r1) { for (int r = r0; r < r1; r++) { const IndexT* lo = h_indices + h_indptr[r] + cursor[r]; @@ -365,15 +357,15 @@ static void ovr_sparse_csr_host_rowstream_impl( (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); } }); - // Prefix sum -> per-row output offsets (gathered mini-CSR row pointer). + // Prefix sum -> per-row output offsets (gather mini-CSR row pointer). h_gather_indptr[0] = 0; for (int r = 0; r < n_rows; r++) h_gather_indptr[r + 1] = checked_int_span( (size_t)h_gather_indptr[r] + (size_t)g_count[r], "rowstream gather nnz"); int batch_nnz = h_gather_indptr[n_rows]; - // Copy each row's run into its slot and advance its cursor (disjoint - // output ranges -> race-free). + // Copy each row's run into its slot, advance cursor (disjoint outputs + // -> race-free). host_parallel_ranges(n_rows, [&](int r0, int r1) { for (int r = r0; r < r1; r++) { IndptrT base = h_indptr[r] + cursor[r]; @@ -451,11 +443,9 @@ static void ovr_sparse_csr_host_rowstream_impl( /** * Host CSR variant of the sparse OVR stream. - * - * The CSR input stays in host memory. We count columns once on the CPU, then - * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter - * on the GPU. This avoids both a full host->device sparse upload and any - * whole-matrix CSR->CSC conversion. + * CSR stays in host memory; columns counted once, then mapped pinned arrays + * feed bounded per-column-batch CSR->CSC scatter on the GPU -- avoids both a + * full sparse upload and any whole-matrix CSR->CSC conversion. */ template static void ovr_sparse_csr_host_streaming_impl( @@ -466,9 +456,8 @@ static void ovr_sparse_csr_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - // Declared before the pool/streams so on exception unwind the streams - // drain (kernels finish reading any mapped host memory) before it is - // unregistered. + // Declared before pool/streams: on exception unwind streams drain (kernels + // finish reading mapped host memory) before unregistration. HostRegisterGuard pin_data; HostRegisterGuard pin_indices; @@ -481,9 +470,8 @@ static void ovr_sparse_csr_host_streaming_impl( size_t data_bytes = total_nnz * sizeof(InT); size_t idx_bytes = total_nnz * sizeof(IndexT); - // When the matrix is too large to stage on the device, the per-batch - // scatter would fall back to bus-latency-bound zero-copy reads. Page the - // CSR through the GPU in row blocks (pinned bulk H2D) instead. + // Too large to stage on device: per-batch scatter would fall back to + // bus-latency-bound zero-copy reads. Page the CSR through in row blocks. if (total_nnz > 0 && data_bytes + idx_bytes > (budget * 3) / 4) { ovr_sparse_csr_host_rowstream_impl( h_data, h_indices, h_indptr, h_group_codes, h_group_sizes, @@ -496,11 +484,10 @@ static void ovr_sparse_csr_host_streaming_impl( cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); - // Stage the indices on the device when they fit, so the per-column - // histogram and the per-batch CSR->CSC scatter read them at HBM speed - // rather than over the bus. Indices are needed by both, so they are staged - // first; the (equally sized) data array is staged later only if it too - // fits. A bulk pageable copy is driver-staged -- no host registration. + // Stage indices on device when they fit so histogram + scatter read at HBM + // speed not over the bus. Both need indices, so staged first; data (equal + // size) staged later only if it fits too. Bulk pageable copy is + // driver-staged -- no host registration. IndexT* d_indices = nullptr; bool indices_staged = total_nnz > 0 && idx_bytes <= budget / 2; if (total_nnz > 0) { @@ -520,9 +507,8 @@ static void ovr_sparse_csr_host_streaming_impl( } // ---- Phase 0: per-column nnz counts on the GPU ---- - // CSR has no column structure, so counting on the CPU is a serial pass over - // every nonzero. Histogram the device-accessible indices instead; only the - // n_cols counts come back for the per-batch prefix sums. + // CSR has no column structure -> CPU count is a serial pass over every nnz. + // Histogram device-accessible indices; only n_cols counts come back. std::vector h_col_counts(n_cols, 0); if (total_nnz > 0) { unsigned int* d_col_counts = pool.alloc(n_cols); @@ -537,11 +523,10 @@ static void ovr_sparse_csr_host_streaming_impl( "OVR host CSR column-count D2H"); } - // Each column batch is sorted in one CUB segmented call (int32 item count) - // and its CSR->CSC transpose lives in per-stream scratch (~BYTES_PER_NNZ - // per stored nonzero). Shrink sub_batch_cols until the densest window fits - // BOTH the int32 limit AND a per-stream slice of the budget, so very tall - // matrices neither overflow CUB nor exhaust memory. + // Each batch sorted in one CUB segmented call (int32 item count); its + // CSR->CSC transpose lives in per-stream scratch (~BYTES_PER_NNZ/nnz). + // Shrink sub_batch_cols until densest window fits BOTH the int32 limit AND + // a per-stream budget slice (tall matrices neither overflow CUB nor OOM). constexpr size_t BYTES_PER_NNZ = sizeof(InT) + sizeof(float) + 2 * sizeof(int) + 8; // buffers + CUB temp size_t batch_nnz_cap = SAFE_BATCH_NNZ; @@ -599,9 +584,8 @@ static void ovr_sparse_csr_host_streaming_impl( per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } - // Stage the data array too when the indices are already resident and the - // data + at least one stream's transpose buffers still fit; the scatter - // then reads values at HBM speed. Otherwise the data stays mapped + // Stage data too when indices resident and data + one stream's transpose + // buffers fit (scatter reads values at HBM speed). Else data stays mapped // zero-copy (bounded for matrices too large to stage). size_t resident = indices_staged ? idx_bytes : 0; bool data_staged = total_nnz > 0 && indices_staged && @@ -758,13 +742,12 @@ static void ovr_sparse_csc_streaming_impl( bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - // Read indptr to host for batch planning + // Read indptr to host for batch planning. std::vector h_indptr(n_cols + 1); cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(IndptrT), cudaMemcpyDeviceToHost); - // Bound each column batch's nnz so CUB item counts stay within int32 and - // the per-stream sort buffers fit the budget. + // Bound each batch's nnz: CUB item counts within int32 + sort buffers fit. { constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; @@ -844,8 +827,7 @@ static void ovr_sparse_csc_streaming_impl( int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), "OVR device CSC active batch nnz"); - // Compute rebased segment offsets on GPU (avoids host pinned-buffer - // race) + // Rebase segment offsets on GPU (avoids host pinned-buffer race). { int count = sb_cols + 1; int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -854,9 +836,9 @@ static void ovr_sparse_csc_streaming_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Sort only stored values (keys=data, vals=row_indices). Row indices - // always fit int32 (n_rows < 2^31); downcast int64 input here so the - // sort + rank stay int32 (half the val buffer) -- the device boundary. + // Sort stored values (keys=data, vals=row_indices). Row indices fit + // int32 (n_rows < 2^31); downcast int64 here so sort + rank stay int32 + // (half the val buffer) -- the device boundary. if (batch_nnz > 0) { const int* idx_src; if constexpr (sizeof(IndexT) > sizeof(int)) { @@ -905,14 +887,10 @@ static void ovr_sparse_csc_streaming_impl( /** * Sparse-aware OVR streaming pipeline for GPU CSR data. - * - * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums - * give exact per-batch nnz and max_batch_nnz for buffer sizing. - * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. - * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via - * atomics) → CUB sort only nonzeros → sparse rank kernel. - * - * Compared to the dense CSR path, sort work drops by ~1/sparsity. + * P0: histogram nnz per column -> per-batch nnz + max_batch_nnz for sizing. + * P1: alloc per-stream buffers sized to max_batch_nnz. + * P2: per sub-batch scatter CSR->CSC (partial atomic transpose) -> CUB sort + * only nonzeros -> sparse rank. Sort work drops ~1/sparsity vs dense. */ template static void ovr_sparse_csr_streaming_impl( @@ -922,7 +900,7 @@ static void ovr_sparse_csr_streaming_impl( bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - // ---- Phase 0: Planning — count nnz per column via histogram ---- + // ---- Phase 0: count nnz per column via histogram ---- RmmScratchPool pool; unsigned int* d_col_counts = pool.alloc(n_cols); cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); @@ -936,8 +914,8 @@ static void ovr_sparse_csr_streaming_impl( cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost); - // Bound each column batch's nnz so CUB item counts stay within int32 and - // the per-stream transpose/sort buffers fit the budget. + // Bound each batch's nnz: CUB item counts within int32 + transpose/sort + // buffers fit. { constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; @@ -950,11 +928,9 @@ static void ovr_sparse_csr_streaming_impl( [&](int c) { return (size_t)h_col_counts[c]; }); } - // Per-batch prefix sums on host + // Per-batch prefix sums on host; flat n_batches x (sub_batch_cols+1). int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t max_batch_nnz = 0; - - // Flat array: n_batches × (sub_batch_cols + 1) offsets std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); std::vector h_batch_nnz(n_batches); @@ -971,13 +947,13 @@ static void ovr_sparse_csr_streaming_impl( if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } - // Upload all batch offsets in one H2D (~20 KB) + // Upload all batch offsets in one H2D. int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); cudaMemcpy(d_all_offsets, h_all_offsets.data(), h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - // ---- Phase 1: Allocate per-stream buffers ---- + // ---- Phase 1: per-stream buffers ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { int max_batch_nnz_i32 = checked_cub_items( @@ -989,8 +965,8 @@ static void ovr_sparse_csr_streaming_impl( int n_streams = N_STREAMS; if (n_batches < n_streams) n_streams = n_batches; - // CSR path needs 4 sort arrays per stream (scatter intermediates + - // CUB output). Fit stream count to available GPU memory. + // CSR path needs 4 sort arrays per stream (scatter intermediates + CUB + // output); fit stream count to available GPU memory. bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); size_t per_stream_bytes = @@ -999,8 +975,8 @@ static void ovr_sparse_csr_streaming_impl( (size_t)n_groups * sub_batch_cols * sizeof(double) + sub_batch_cols * sizeof(double); if (rank_use_gmem) { - // gmem rank fallback (n_groups too large for smem): per-stream - // d_nz_scratch accumulator, same size as sub_rank_sums. + // gmem fallback (n_groups too large for smem): per-stream d_nz_scratch, + // same size as sub_rank_sums. per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } @@ -1013,12 +989,12 @@ static void ovr_sparse_csr_streaming_impl( int scatter_blocks = (n_rows + tpb - 1) / tpb; struct StreamBuf { - int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets - int* write_pos; // [sub_batch_cols] atomic write counters - float* csc_vals; // [max_batch_nnz] transposed values - int* csc_row_idx; // [max_batch_nnz] transposed row indices - float* keys_out; // [max_batch_nnz] CUB sort output - int* vals_out; // [max_batch_nnz] CUB sort output + int* col_offsets; // CSC-style offsets + int* write_pos; // atomic write counters + float* csc_vals; // transposed values + int* csc_row_idx; // transposed row indices + float* keys_out; // CUB sort output + int* vals_out; // CUB sort output uint8_t* cub_temp; double* sub_rank_sums; double* sub_tie_corr; @@ -1044,7 +1020,7 @@ static void ovr_sparse_csr_streaming_impl( cudaDeviceSynchronize(); - // ---- Phase 2: Stream loop ---- + // ---- Phase 2: stream loop ---- int col = 0; for (int b = 0; b < n_batches; b++) { int sb_cols = std::min(sub_batch_cols, n_cols - col); @@ -1058,18 +1034,18 @@ static void ovr_sparse_csr_streaming_impl( cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), cudaMemcpyDeviceToDevice, stream); - // write_pos = col_offsets[0..sb_cols-1] (same D2D source) + // write_pos = col_offsets[0..sb_cols-1] (same D2D source). cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), cudaMemcpyDeviceToDevice, stream); if (batch_nnz > 0) { - // Scatter CSR → CSC for this sub-batch + // Scatter CSR -> CSC for this sub-batch. csr_scatter_to_csc_kernel<<>>( csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, buf.csc_row_idx, n_rows, col, col + sb_cols); CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); - // Sort only the nonzeros + // Sort only the nonzeros. cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, buf.csc_vals, buf.keys_out, buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 928b8af1..c24913bc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -91,8 +91,8 @@ void register_sparse_bindings(nb::module_& m) { int); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, int64_t); - // int64 row indices (int64 indptr): pass indices natively, downcast to - // int32 per-batch on-device rather than a full host int32 copy. + // int64 row indices: pass natively, downcast to int32 per-batch on-device + // (avoids a full host int32 copy). RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64_idx64", float, int64_t, int64_t); RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64_idx64", double, @@ -133,8 +133,8 @@ void register_sparse_bindings(nb::module_& m) { int); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, int64_t); - // int64 column indices (int64 indptr): pass indices natively to avoid a - // full int32 copy of every nonzero (~nnz*4 bytes) on large matrices. + // int64 column indices: pass natively to avoid a full int32 copy of every + // nonzero (~nnz*4 bytes) on large matrices. RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64_idx64", float, int64_t, int64_t); RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64_idx64", double, @@ -216,8 +216,8 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, int64_t); - // int64 row indices: read natively (extraction only, never sorted) to skip - // the full host int32 copy. + // int64 row indices: read natively (extraction only, never sorted), + // skipping the full host int32 copy. RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64_idx64", float, int64_t, int64_t); RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64_idx64", double, @@ -260,8 +260,8 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, int64_t); - // int64 column indices (int64 indptr): pass indices natively to avoid a - // full int32 copy of every nonzero (~nnz*4 bytes) on large matrices. + // int64 column indices: pass natively to avoid a full int32 copy of every + // nonzero (~nnz*4 bytes) on large matrices. RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64_idx64", float, int64_t, int64_t); RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64_idx64", double, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index b9e232fe..73a51e74 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -5,44 +5,19 @@ #include "wilcoxon_block_reduce.cuh" #include "wilcoxon_ovr_tie_walk.cuh" -/** - * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. - * - * Sparse rank_genes_groups now rejects explicit negative sparse values before - * reaching CUDA, so after CUB sort each column segment is: - * [stored_zeros..., positives...] - * - * Implicit zeros (n_rows - nnz_stored) join stored zeros as the first tie - * block. The kernel ranks only stored positive values and adds each group's - * zero contribution analytically. - * - * Full sorted array (conceptual): - * [ALL_zeros (stored+implicit)..., positives...] - * - * Rank offsets: - * positive at stored pos i : full pos = i + n_implicit_zero - * zeros : avg rank = (total_zero + 1) / 2 - * - * Shared-memory layout (doubles): - * grp_sums[n_groups] rank-sum accumulators - * grp_nz_count[n_groups] nonzero-per-group counters - * warp_buf[32] tie-correction reduction scratch - * - * n_rows is the ranking population, including rows whose group code is the - * n_groups sentinel. Sentinel rows contribute to the "rest" distribution and - * tie-correction denominator but do not receive rank-sum accumulation. - * - * Grid: (sb_cols,) Block: (tpb,) - */ -// HEADLINE sparse-OVR optimization (OVR-only). Ranks ONLY stored positive -// values; all zeros (stored + implicit n_rows-nnz) are treated as one leading -// tie block ranked analytically at (total_zero+1)/2, and each group's zero -// contribution is applied in closed form. Cost is O(nnz log nnz) per column, -// not O(n_rows log n_rows). The `use_gmem` flag selects shared- vs -// global-memory accumulators (see sparse_ovr_smem_config) -- CRITICAL: the -// use_gmem path is REQUIRED for large n_groups (Perturb-seq) and must not be -// removed. Validity relies on the upstream rejection of explicit negative -// sparse values, which guarantees zeros form the first tie block. +// Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. Ranks +// ONLY stored positives; all zeros (stored + implicit n_rows-nnz) form one +// leading tie block ranked analytically at (total_zero+1)/2, each group's zero +// contribution in closed form -> O(nnz log nnz)/col. Sentinel-group (n_groups) +// rows feed the rest/tie denominator but get no rank-sum accumulation. +// +// CRITICAL: validity relies on upstream rejection of negative sparse values +// (guarantees zeros form the first tie block). use_gmem selects shared- vs +// global-memory accumulators (sparse_ovr_smem_config), REQUIRED for large +// n_groups (Perturb-seq) -- do not remove. +// +// Grid (sb_cols,), Block (tpb,). Shared (doubles): grp_sums[n_groups] + +// grp_nz_count[n_groups] + warp_buf[32]. template __global__ void rank_sums_sparse_ovr_kernel( const float* __restrict__ sorted_vals, @@ -106,10 +81,8 @@ __global__ void rank_sums_sparse_ovr_kernel( int total_zero = n_implicit_zero + n_stored_zero; double zero_avg_rank = (total_zero > 0) ? (total_zero + 1.0) / 2.0 : 0.0; - // Rank offset for positive stored values: - // full_pos(i) = i + n_implicit_zero for i >= pos_start - // So avg_rank for tie group [a,b) of positives: - // = n_implicit_zero + (a + b + 1) / 2 + // Positive rank offset: full_pos(i)=i+n_implicit_zero; tie group [a,b) + // avg_rank = n_implicit_zero + (a+b+1)/2. int offset_pos = n_implicit_zero; // Count stored positives per group. @@ -168,12 +141,11 @@ __global__ void rank_sums_sparse_ovr_kernel( } } -// Shared sparse-OVR rank launch, used by all four sparse OVR impls (they differ -// only in how they produce the sorted nonzeros and how they scatter results). -// Optionally zeroes the global-memory accumulators, then launches the -// analytic-zero rank kernel. use_gmem is the CRITICAL large-n_groups / -// perturbation fallback (see sparse_ovr_smem_config) — DO NOT drop the gmem -// branch. ValT is the sorted-row-index type (int everywhere today). +// Shared sparse-OVR rank launch (all four sparse OVR impls). Optionally zeroes +// the gmem accumulators, then launches the analytic-zero rank kernel. use_gmem +// is the CRITICAL large-n_groups/perturbation fallback (see +// sparse_ovr_smem_config) — DO NOT drop the gmem branch. ValT is the +// sorted-row-index type (int everywhere today). template static inline void launch_ovr_sparse_rank( const float* sorted_vals, const ValT* sorted_row_idx, @@ -196,14 +168,10 @@ static inline void launch_ovr_sparse_rank( } // CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). -// -// Decide smem-vs-gmem for the sparse-OVR stats cast-and-accumulate kernel -// (sums / sq-sums / nnz). Needs n_arrays*n_groups doubles in smem; when that -// exceeds the per-block limit, use_gmem=true selects -// ovr_cast_and_accumulate_sparse_global_kernel, which accumulates directly in -// global memory. Same large-n_groups workloads that drive -// sparse_ovr_smem_config to gmem also drive this one; both fallbacks are -// load-bearing, not dead. +// smem-vs-gmem for the sparse-OVR stats cast+accumulate kernel. Needs +// n_arrays*n_groups doubles in smem; over the per-block limit, use_gmem=true +// selects ovr_cast_and_accumulate_sparse_global_kernel (accumulates in gmem). +// Load-bearing fallback, not dead. static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, bool& use_gmem) { int n_arrays = 1 + (compute_nnz ? 1 : 0); @@ -217,9 +185,8 @@ static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, } // Shared cast+accumulate loop for the two sparse-OVR stats kernels. Casts each -// stored value to f32 (data_f32_out) and atomically accumulates per-group sums -// (and nonzero counts) into sums/nnz, strided by acc_stride (1 for a per-block -// smem buffer, sb_cols for the global row-major layout). +// stored value to f32 and atomically accumulates per-group sums (+nnz), strided +// by acc_stride (1 for per-block smem, sb_cols for the gmem row-major layout). template __device__ __forceinline__ void accumulate_group_stats( const InT* data_in, float* data_f32_out, const IndexT* indices, @@ -240,16 +207,11 @@ __device__ __forceinline__ void accumulate_group_stats( } /** - * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. - * - * Sub-batch CSC data is laid out contiguously: values for column c live - * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each - * stored value, read the native-dtype InT, write a float32 copy for the - * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. Implicit - * zeros contribute nothing to any of these stats. - * - * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). - * Shared memory: 3 * n_groups doubles. + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. Sub-batch + * CSC column c lives at [col_seg_offsets[c], col_seg_offsets[c+1]); writes an + * f32 copy for the CUB sort and accumulates per-group sum/nnz in f64 (implicit + * zeros contribute nothing). Block-per-column (grid (sb_cols,), block (tpb,)), + * smem (1+compute_nnz)*n_groups doubles. */ template __global__ void ovr_cast_and_accumulate_sparse_kernel( @@ -264,8 +226,8 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( int seg_start = col_seg_offsets[col]; int seg_end = col_seg_offsets[col + 1]; - // Packed layout matching cast_accumulate_smem_config, which sizes the - // dynamic smem as (1 + compute_nnz) * n_groups doubles. + // Packed layout matching cast_accumulate_smem_config ((1+compute_nnz)* + // n_groups doubles). extern __shared__ double smem[]; double* s_sum = smem; double* s_nnz = smem + n_groups; @@ -289,11 +251,10 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } } -// CRITICAL — DO NOT REMOVE. Global-memory variant of the stats accumulator, -// selected by cast_accumulate_smem_config when n_groups is too large for the -// smem version. Required for Perturb-seq-scale n_groups; the smem kernel cannot -// launch when its (n_arrays*n_groups) double buffer exceeds the per-block -// limit. +// CRITICAL — DO NOT REMOVE. Gmem variant of the stats accumulator, selected by +// cast_accumulate_smem_config when n_groups is too large for the smem kernel +// (its n_arrays*n_groups double buffer exceeds the per-block limit). Required +// for Perturb-seq-scale n_groups. template __global__ void ovr_cast_and_accumulate_sparse_global_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 4a5792c2..5ada25d3 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -276,6 +276,32 @@ def _ovr_z_pvals( ) +def _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + *, + use_continuity, + return_u_values, + n_groups, +): + """OVR epilogue: z/p-values -> host -> per-group (idx, scores, pvals).""" + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + scores_host = scores.get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + def _ovo_z_pvals( rank_sums: cp.ndarray, test_sizes: cp.ndarray, @@ -302,6 +328,45 @@ def _ovo_z_pvals( ) +def _finish_ovo( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + *, + tie_correct, + use_continuity, + return_u_values, + rg, + test_group_indices, + logfoldchanges_gpu, +): + """OVO epilogue: z/p-values; stash GPU result if requested, else host tuples.""" + scores, p_values = _ovo_z_pvals( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + logfoldchanges_gpu, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + def _host_sparse_fn_and_arrays(module, base_name: str, X): data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float64: @@ -628,7 +693,7 @@ def _wilcoxon_vs_rest( total_nnz=total_nnz, ) - scores, p_values = _ovr_z_pvals( + return _finish_ovr( rank_sums, group_sizes_dev, rest_sizes, @@ -636,10 +701,8 @@ def _wilcoxon_vs_rest( tie_corr, use_continuity=use_continuity, return_u_values=return_u_values, + n_groups=n_groups, ) - scores_host = scores.get() - p_host = p_values.get() - return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] if ( cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) @@ -686,7 +749,7 @@ def _wilcoxon_vs_rest( sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, ) - scores, p_values = _ovr_z_pvals( + return _finish_ovr( rank_sums, group_sizes_dev, rest_sizes, @@ -694,10 +757,8 @@ def _wilcoxon_vs_rest( tie_corr, use_continuity=use_continuity, return_u_values=return_u_values, + n_groups=n_groups, ) - scores_host = scores.get() - p_host = p_values.get() - return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) @@ -761,7 +822,7 @@ def _wilcoxon_vs_rest( total_sums=total_sums, total_nnz=total_nnz, ) - scores, p_values = _ovr_z_pvals( + return _finish_ovr( rank_sums, group_sizes_dev, rest_sizes, @@ -769,10 +830,8 @@ def _wilcoxon_vs_rest( tie_corr, use_continuity=use_continuity, return_u_values=return_u_values, + n_groups=n_groups, ) - scores_host = scores.get() - p_host = p_values.get() - return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) @@ -1021,7 +1080,7 @@ def _wilcoxon_with_reference( n_ref=n_ref, ) - scores, p_values = _ovo_z_pvals( + return _finish_ovo( rank_sums, test_sizes, n_ref, @@ -1029,21 +1088,10 @@ def _wilcoxon_with_reference( tie_correct=tie_correct, use_continuity=use_continuity, return_u_values=return_u_values, + rg=rg, + test_group_indices=test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, ) - if rg._store_wilcoxon_gpu_result: - rg._wilcoxon_gpu_result = ( - np.asarray(test_group_indices, dtype=np.intp), - scores, - p_values, - logfoldchanges_gpu, - ) - return [] - scores_host = scores.get() - p_host = p_values.get() - return [ - (group_index, scores_host[slot], p_host[slot]) - for slot, group_index in enumerate(test_group_indices) - ] if ( cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) @@ -1098,7 +1146,7 @@ def _wilcoxon_with_reference( sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, ) - scores, p_values = _ovo_z_pvals( + return _finish_ovo( rank_sums, test_sizes, n_ref, @@ -1106,21 +1154,10 @@ def _wilcoxon_with_reference( tie_correct=tie_correct, use_continuity=use_continuity, return_u_values=return_u_values, + rg=rg, + test_group_indices=test_group_indices, + logfoldchanges_gpu=None, ) - if rg._store_wilcoxon_gpu_result: - rg._wilcoxon_gpu_result = ( - np.asarray(test_group_indices, dtype=np.intp), - scores, - p_values, - None, - ) - return [] - scores_host = scores.get() - p_host = p_values.get() - return [ - (group_index, scores_host[slot], p_host[slot]) - for slot, group_index in enumerate(test_group_indices) - ] chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index a79917fd..33c9f2a4 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -39,12 +39,14 @@ def _make_nonnegative(adata): # full-sort path (correct for any sign), so the result matches running the same # method on the dense matrix. (t-test/t-test_overestim_var/logreg never need this # and accept signed sparse data directly -- e.g. mixscape's LDA t-test.) +# (method, reference) combos: vs-rest for wilcoxon + binned, plus the OVO +# (with-reference) wilcoxon path. binned has no with-reference mode. @pytest.mark.parametrize( - "method", - ["wilcoxon", "wilcoxon_binned"], + ("method", "reference"), + [("wilcoxon", "rest"), ("wilcoxon_binned", "rest"), ("wilcoxon", "b")], ) @pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) -def test_rank_genes_groups_sparse_negative_values_fallback(method, fmt): +def test_rank_genes_groups_sparse_negative_values_fallback(method, reference, fmt): X = np.array( [ [-1.0, 0.0, 2.0], @@ -63,8 +65,9 @@ def test_rank_genes_groups_sparse_negative_values_fallback(method, fmt): dense_fmt = "cupy_dense" if fmt.startswith("cupy") else "numpy_dense" dense_adata = sc.AnnData(X=_to_format(X, dense_fmt), obs=obs.copy(), var=var.copy()) - rsc.tl.rank_genes_groups(sparse_adata, "group", method=method, use_raw=False) - rsc.tl.rank_genes_groups(dense_adata, "group", method=method, use_raw=False) + kw = {"method": method, "reference": reference, "use_raw": False} + rsc.tl.rank_genes_groups(sparse_adata, "group", **kw) + rsc.tl.rank_genes_groups(dense_adata, "group", **kw) # Sparse-with-negatives falls back to the dense ranking -> identical result. sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] @@ -176,7 +179,10 @@ def test_rank_genes_groups_return_format_removed(): @pytest.mark.parametrize("reference", ["rest", "b"]) -@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_csr"]) +@pytest.mark.parametrize( + "fmt", + ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_dense", "cupy_csr", "cupy_csc"], +) def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): X = np.array( [ @@ -1453,14 +1459,14 @@ def run(arr): ) -@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr"]) def test_wilcoxon_sparse_float16_data_raises(fmt): - # Unsupported float16 sparse data is rejected with a TypeError. + # Unsupported float16 sparse data (host + device) is rejected with TypeError. rng = np.random.default_rng(0) dense = np.abs(rng.standard_normal((40, 4))).astype(np.float32) - mat = sp.csr_matrix(dense) if fmt == "scipy_csr" else sp.csc_matrix(dense) - mat.data = mat.data.astype(np.float16) - assert mat.data.dtype == np.float16 + mat = _to_format(dense, fmt) + xp = cp if fmt.startswith("cupy") else np + mat.data = mat.data.astype(xp.float16) adata = sc.AnnData( X=mat, obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(40)])}), @@ -1470,100 +1476,6 @@ def test_wilcoxon_sparse_float16_data_raises(fmt): rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) -@pytest.mark.parametrize("reference", ["rest", "b"]) -@pytest.mark.parametrize("fmt", ["scipy_csc", "cupy_csc", "cupy_dense"]) -def test_rank_genes_groups_wilcoxon_return_u_values_more_formats(reference, fmt): - # U-value + continuity epilogue on CSC (host/device) and device-dense. - X = np.array( - [ - [5.0, 0.0, 1.0, 2.0], - [4.0, 0.0, 1.0, 2.0], - [1.0, 3.0, 2.0, 2.0], - [0.0, 2.0, 2.0, 2.0], - [2.0, 1.0, 0.0, 3.0], - [3.0, 1.0, 0.0, 3.0], - ], - dtype=np.float32, - ) - labels = np.array(["a", "a", "b", "b", "c", "c"]) - adata = sc.AnnData( - X=_to_format(X, fmt), - obs=pd.DataFrame({"group": pd.Categorical(labels)}), - var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), - ) - - rsc.tl.rank_genes_groups( - adata, - "group", - groups=["a"], - reference=reference, - method="wilcoxon", - use_raw=False, - tie_correct=True, - use_continuity=True, - return_u_values=True, - n_genes=adata.n_vars, - ) - - result = adata.uns["rank_genes_groups"] - assert result["params"]["return_u_values"] is True - assert result["scores"].dtype["a"] == np.dtype("float64") - - df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") - mask_group = labels == "a" - mask_ref = labels != "a" if reference == "rest" else labels == reference - expected = np.array( - [ - mannwhitneyu( - X[mask_group, gene], - X[mask_ref, gene], - alternative="two-sided", - ).statistic - for gene in range(X.shape[1]) - ], - dtype=np.float64, - ) - gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} - expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) - np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) - - -@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) -def test_rank_genes_groups_sparse_negative_values_fallback_ovo(fmt): - # Sparse-negative dense fallback on the OVO (with-reference) path. - X = np.array( - [ - [-1.0, 0.0, 2.0], - [0.0, 1.0, 0.0], - [2.0, 0.0, 1.0], - [0.0, 3.0, 0.0], - [-2.0, 1.0, 0.0], - [1.0, 0.0, 3.0], - ], - dtype=np.float64, - ) - obs = pd.DataFrame({"group": pd.Categorical(list("aaabbb"), categories=["a", "b"])}) - var = pd.DataFrame(index=["g0", "g1", "g2"]) - - sparse_adata = sc.AnnData(X=_to_format(X, fmt), obs=obs.copy(), var=var.copy()) - dense_fmt = "cupy_dense" if fmt.startswith("cupy") else "numpy_dense" - dense_adata = sc.AnnData(X=_to_format(X, dense_fmt), obs=obs.copy(), var=var.copy()) - - kw = {"groupby": "group", "method": "wilcoxon", "use_raw": False, "reference": "b"} - rsc.tl.rank_genes_groups(sparse_adata, **kw) - rsc.tl.rank_genes_groups(dense_adata, **kw) - - sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] - dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] - for group in sp_scores.dtype.names: - np.testing.assert_allclose( - np.asarray(sp_scores[group], dtype=float), - np.asarray(dn_scores[group], dtype=float), - rtol=1e-13, - atol=1e-13, - ) - - @pytest.mark.parametrize("reference", ["rest", "2"]) def test_wilcoxon_group_subset_column_order_matches_scanpy(reference): """Output column order must echo the user's ``groups=`` list (scanpy parity), @@ -1793,19 +1705,17 @@ def _assert_ovo_matches_scanpy(adata, reference): ) -def test_ovo_tier_bands_medium_large_match_scanpy(): - """OVO dense-tiered path: small groups (20/50/300, all <= 512) run through - MEDIUM and a 1000-cell group through LARGE; must match scanpy.""" - adata = _anndata_with_group_sizes( - {"ref": 40, "g20": 20, "g50": 50, "g300": 300, "g1000": 1000}, seed=1 - ) - _assert_ovo_matches_scanpy(adata, reference="ref") - - -def test_ovo_tier_band_huge_match_scanpy(): - """OVO dense-tiered path must hit the HUGE band (CUB segmented sort, a - test-group > 2500 cells) and match scanpy.""" - adata = _anndata_with_group_sizes({"ref": 40, "huge": 3000}, seed=2) +@pytest.mark.parametrize( + ("sizes", "seed"), + [ + ({"ref": 40, "g20": 20, "g50": 50, "g300": 300, "g1000": 1000}, 1), + ({"ref": 40, "huge": 3000}, 2), + ], +) +def test_ovo_tier_bands_match_scanpy(sizes, seed): + """OVO dense-tiered path across MEDIUM/LARGE (groups <= 512 plus a 1000-cell + LARGE) and HUGE (a > 2500-cell group, CUB segmented sort); match scanpy.""" + adata = _anndata_with_group_sizes(sizes, seed=seed) _assert_ovo_matches_scanpy(adata, reference="ref") @@ -2289,20 +2199,3 @@ def test_host_csc_int64_indices_cast_matches_int32(reference): np.testing.assert_array_equal( np.asarray(r64[fld][grp]), np.asarray(r32[fld][grp]) ) - - -def test_device_sparse_float16_raises(): - """A cupy sparse matrix with float16 data raises a clear TypeError (the - device counterpart of the host float16 guard).""" - rng = np.random.default_rng(16) - dense = np.abs(rng.standard_normal((60, 4))).astype(np.float32) - dense[dense < 0.5] = 0.0 - mat = cpsp.csr_matrix(cp.asarray(dense)) - mat.data = mat.data.astype(cp.float16) - adata = sc.AnnData( - X=mat, - obs=pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(60)])}), - var=pd.DataFrame(index=[f"g{i}" for i in range(4)]), - ) - with pytest.raises(TypeError, match="float32"): - rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) From 8567329d93a8055fcae96e01f387d4a8c3057f4f Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 25 Jun 2026 12:24:00 +0200 Subject: [PATCH 32/36] add memory safety Signed-off-by: Intron7 --- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 10 +++++++ .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 28 +++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 72b0e1f5..1d07b780 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -103,6 +103,16 @@ constexpr int OVO_LARGE_MAX = 2500; // copy ≈ 1GB/stream. Sub-batching keeps (n_g * eff_sb_cols) <= this. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; +// Budget-aware OVO-host pack sizing. Per-stream device scratch that does NOT +// scale with pack nnz: dense + sorted slabs (each <= GROUP_DENSE_BUDGET) plus +// rank/tie/seg/cub headroom. Reserved per target stream when bounding pack nnz +// so the resident packs + sorted ref cache fit device free. +constexpr size_t OVO_PACK_FIXED_PER_STREAM = + 4 * GROUP_DENSE_BUDGET_ITEMS * sizeof(float); // ~2 GB +// Floor for the budget-derived pack-nnz cap: avoid pathological over-splitting +// into thousands of tiny packs when device memory is very tight. +constexpr size_t OVO_MIN_PACK_NNZ = 64 * 1024 * 1024; // 64M nnz + // Host->device staging-ring slot cap (nnz). Bounds the page-locked footprint: // a pack's device buffer is filled in row-blocks of <= this many nonzeros, so // the cold pin stays small instead of seconds when pack nnz is large. 32M nnz diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 1d5a5105..00364e17 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -345,6 +345,29 @@ static void ovo_streaming_csr_host_impl( // counts + offsets); splits dense groups across more packs. constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe + // Budget-aware pack-nnz cap: shrink packs so ~N_STREAMS of them plus + // the sorted ref cache and per-stream fixed scratch fit in 90% of free + // device memory. On a tight GPU this yields more, smaller packs (the + // per-stream device buffers then fit; the Phase-2 stream clamp sets the + // final stream count); on a big GPU it stays at the int32-safe cap, so + // large-GPU pack layout (and timing) is unchanged. Only the per-stream + // pack buffers scale with this; the ref cache is bounded separately. + size_t pack_nnz_cap = SAFE_PACK_NNZ; + { + int target_streams = std::min(N_STREAMS, n_test); + if (target_streams < 1) target_streams = 1; + size_t dev_budget = rmm_available_device_bytes(0.9); + size_t ref_bytes = (size_t)n_ref * (size_t)n_cols * sizeof(float); + size_t reserve = (size_t)target_streams * OVO_PACK_FIXED_PER_STREAM; + size_t grp_avail = + dev_budget > ref_bytes ? dev_budget - ref_bytes : 0; + size_t data_avail = grp_avail > reserve ? grp_avail - reserve : 0; + size_t cap = + data_avail / ((size_t)target_streams * 2 * sizeof(float)); + if (cap < OVO_MIN_PACK_NNZ) cap = OVO_MIN_PACK_NNZ; + if (cap < pack_nnz_cap) pack_nnz_cap = cap; + } + int cur_first = 0; int cur_rows = 0; size_t cur_nnz = 0; @@ -353,9 +376,8 @@ static void ovo_streaming_csr_host_impl( size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - h_grp_indptr_compact[h_grp_offsets[g]]); int new_rows = cur_rows + n_g; - bool can_add = - (cur_rows == 0) || - (new_rows <= target_rows && cur_nnz + nnz_g <= SAFE_PACK_NNZ); + bool can_add = (cur_rows == 0) || (new_rows <= target_rows && + cur_nnz + nnz_g <= pack_nnz_cap); if (!can_add) { size_t sb_size = std::min((size_t)n_cols, From 45d2a0b85b9dcb234c2bdc65c79f88a422b074c6 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 25 Jun 2026 15:42:55 +0200 Subject: [PATCH 33/36] update streaming Signed-off-by: Intron7 --- .../_cuda/streaming/streaming.cuh | 832 ++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon.cu | 433 ++++++++- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 410 +-------- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 25 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 322 +++---- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 34 + .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 321 +++---- .../tools/_rank_genes_groups/_wilcoxon.py | 179 ++-- 8 files changed, 1713 insertions(+), 843 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/streaming/streaming.cuh diff --git a/src/rapids_singlecell/_cuda/streaming/streaming.cuh b/src/rapids_singlecell/_cuda/streaming/streaming.cuh new file mode 100644 index 00000000..fe9fcebf --- /dev/null +++ b/src/rapids_singlecell/_cuda/streaming/streaming.cuh @@ -0,0 +1,832 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../nb_types.h" +#include "../rmm_scratch.h" + +// Default thread-per-block for utility kernels shared by streaming pipelines. +constexpr int UTIL_BLOCK_SIZE = 256; +constexpr int DEFAULT_STREAMING_STREAMS = 4; +// Max per-batch nnz for segmented CUDA primitives that take int32 item counts. +constexpr size_t STREAMING_SAFE_BATCH_NNZ = 2000000000; // < INT_MAX +// Above this host span, avoid whole-array cudaHostRegister and use bounded +// staging. Moderate arrays keep the lower-overhead direct async-copy path. +constexpr size_t HOST_STREAMING_DIRECT_PIN_LIMIT_BYTES = + 16ULL * 1024ULL * 1024ULL * 1024ULL; + +// Host thread count for CPU-side staging passes: hardware concurrency, capped. +static inline int host_worker_count() { + unsigned hw = std::thread::hardware_concurrency(); + return (int)std::min(hw ? hw : 4u, 32u); +} + +// Run fn(chunk, r0, r1) over a partition of [0, n); `chunk` = 0-based worker +// index. fn runs concurrently: read-only shared state, disjoint output ranges +// (keyed by chunk or [r0,r1)). Returns chunks used; serial for small n. +template +static inline int host_parallel_chunks(int n, F fn) { + if (n <= 0) return 0; + int n_threads = host_worker_count(); + if (n_threads <= 1 || n < 4096) { + fn(0, 0, n); + return 1; + } + int chunk = (n + n_threads - 1) / n_threads; + std::vector pool; + pool.reserve(n_threads); + for (int t = 0; t < n_threads; t++) { + int r0 = t * chunk; + if (r0 >= n) break; + int r1 = std::min(n, r0 + chunk); + pool.emplace_back([&fn, t, r0, r1]() { fn(t, r0, r1); }); + } + int used = (int)pool.size(); + for (std::thread& th : pool) th.join(); + return used; +} + +// Run fn(r0, r1) over a partition of [0, n) across hardware threads (serial for +// small n). Concurrent: read-only shared state, disjoint output ranges. +template +static inline void host_parallel_ranges(int n, F fn) { + host_parallel_chunks(n, [&fn](int, int r0, int r1) { fn(r0, r1); }); +} + +static inline int checked_cub_items(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds CUB int item limit"); + } + return (int)count; +} + +static inline int checked_int_span(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds int32 offset limit"); + } + return (int)count; +} + +static inline int checked_int_product(size_t a, size_t b, const char* context) { + if (a != 0 && b > (size_t)std::numeric_limits::max() / a) { + throw std::runtime_error(std::string(context) + + " exceeds int32 item limit"); + } + return (int)(a * b); +} + +template +struct SparseWindowDTypes { + using value_type = DeviceValueT; + using index_type = DeviceIndexT; + using accum_type = AccumT; + + static constexpr size_t bytes_per_nnz = + sizeof(value_type) + sizeof(index_type); +}; + +using WilcoxonSparseWindowDTypes = SparseWindowDTypes; + +template +static inline size_t sparse_window_nnz_bytes(size_t nnz) { + return nnz * DTypes::bytes_per_nnz; +} + +template +static inline size_t sparse_window_accum_bytes(size_t count) { + return count * sizeof(typename DTypes::accum_type); +} + +static inline void host_clear_id_map(int* id_map, int n_items) { + std::fill(id_map, id_map + n_items, -1); +} + +static inline void host_build_id_map(const int* ids, int n_ids, int* id_map, + int n_items, const char* what) { + host_clear_id_map(id_map, n_items); + for (int local = 0; local < n_ids; local++) { + int id = ids[local]; + if (id < 0 || id >= n_items) { + throw std::runtime_error(std::string(what) + + " id is out of bounds"); + } + id_map[id] = local; + } +} + +static inline void host_build_contiguous_id_map(int first, int count, + int* id_map, int n_items, + const char* what) { + if (first < 0 || count < 0 || first > n_items - count) { + throw std::runtime_error(std::string(what) + + " contiguous id window is out of bounds"); + } + host_clear_id_map(id_map, n_items); + for (int local = 0; local < count; local++) id_map[first + local] = local; +} + +// Stream-count clamps: never use more streams than column batches, nor more +// than the per-stream memory budget allows. +static inline int clamp_streams_by_cols( + int n_cols, int sub_batch_cols, + int max_streams = DEFAULT_STREAMING_STREAMS) { + int n = max_streams; + if (n_cols < n * sub_batch_cols) + n = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + return n; +} + +static inline int clamp_streams_by_budget(int n_streams, + size_t per_stream_bytes, + size_t budget) { + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + return n_streams; +} + +// Scatter a [rows, sb_cols] device sub-batch (row-major doubles, src stride +// sb_cols) into `dst` (stride n_cols). `dst` must point at the dest column +// offset (e.g. out + col). +static inline void scatter_cols_2d(double* dst, const double* src, int rows, + int n_cols, int sb_cols, + cudaStream_t stream) { + cudaMemcpy2DAsync(dst, n_cols * sizeof(double), src, + sb_cols * sizeof(double), sb_cols * sizeof(double), rows, + cudaMemcpyDeviceToDevice, stream); +} + +// Halve sub_batch_cols until the densest window holds <= cap nonzeros, keeping +// every batch's nnz within int32 for CUB and bounding per-stream transpose/sort +// scratch. col_nnz(i) = nnz of column i. Worst case returns 1 (single column, +// nnz <= n_rows). +template +static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, + size_t cap, ColNnz col_nnz) { + if (cap < 1) cap = 1; + auto max_window = [&](int s) { + size_t mx = 0; + for (int c = 0; c < n_cols; c += s) { + int e = std::min(c + s, n_cols); + size_t sum = 0; + for (int i = c; i < e; i++) sum += col_nnz(i); + if (sum > mx) mx = sum; + } + return mx; + }; + while (sub_batch_cols > 1 && max_window(sub_batch_cols) > cap) + sub_batch_cols = (sub_batch_cols + 1) / 2; + return sub_batch_cols; +} + +struct ColumnBatchPlan { + int sub_batch_cols = 0; + int n_batches = 0; + size_t max_nnz = 0; + std::vector offsets; + std::vector nnz; +}; + +struct HostCompactSparseWindowPlan { + int major_count = 0; + size_t nnz = 0; + std::vector indptr; +}; + +struct DenseColumnBatchPlan { + int sub_batch_cols = 0; + int n_batches = 0; + size_t max_items = 0; +}; + +static inline DenseColumnBatchPlan plan_dense_column_batches( + int n_rows, int n_cols, int sub_batch_cols, size_t cap, const char* what) { + DenseColumnBatchPlan plan; + if (sub_batch_cols < 1) sub_batch_cols = 1; + if (cap < 1) cap = 1; + checked_cub_items((size_t)n_rows, what); + + size_t max_cols = + n_rows > 0 ? cap / (size_t)n_rows : (size_t)sub_batch_cols; + if (max_cols < 1) max_cols = 1; + if ((size_t)sub_batch_cols > max_cols) sub_batch_cols = (int)max_cols; + + plan.sub_batch_cols = sub_batch_cols; + plan.n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + plan.max_items = (size_t)n_rows * (size_t)sub_batch_cols; + checked_cub_items(plan.max_items, what); + return plan; +} + +template +static inline ColumnBatchPlan plan_column_batches_from_counts( + int n_cols, int sub_batch_cols, size_t cap, CountAt count_at, + const char* what) { + ColumnBatchPlan plan; + plan.sub_batch_cols = + cap_sub_batch_by_nnz(n_cols, sub_batch_cols, cap, count_at); + plan.n_batches = (n_cols + plan.sub_batch_cols - 1) / plan.sub_batch_cols; + plan.offsets.assign((size_t)plan.n_batches * (plan.sub_batch_cols + 1), 0); + plan.nnz.assign(plan.n_batches, 0); + for (int b = 0; b < plan.n_batches; b++) { + int col_start = b * plan.sub_batch_cols; + int sb = std::min(plan.sub_batch_cols, n_cols - col_start); + int* off = &plan.offsets[(size_t)b * (plan.sub_batch_cols + 1)]; + for (int i = 0; i < sb; i++) { + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)count_at(col_start + i), what); + } + plan.nnz[b] = (size_t)off[sb]; + if (plan.nnz[b] > plan.max_nnz) plan.max_nnz = plan.nnz[b]; + } + return plan; +} + +template +static inline ColumnBatchPlan plan_csc_column_batches(const IndptrT* h_indptr, + int n_cols, + int sub_batch_cols, + size_t cap, + const char* what) { + return plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }, what); +} + +static inline int* upload_batch_offsets(const ColumnBatchPlan& plan, + RmmScratchPool& pool) { + int* d_all_offsets = pool.alloc(plan.offsets.size()); + cudaMemcpy(d_all_offsets, plan.offsets.data(), + plan.offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + return d_all_offsets; +} + +template +static HostCompactSparseWindowPlan plan_compact_sparse_window( + int major_count, CountAt count_at, const char* what) { + HostCompactSparseWindowPlan plan; + plan.major_count = major_count; + plan.indptr.assign((size_t)major_count + 1, 0); + if (major_count <= 0) return plan; + + std::vector counts(major_count, 0); + host_parallel_ranges(major_count, [&](int i0, int i1) { + for (int i = i0; i < i1; i++) counts[i] = count_at(i); + }); + + size_t run = 0; + for (int i = 0; i < major_count; i++) { + plan.indptr[i] = checked_int_span(run, what); + run += counts[i]; + } + plan.indptr[major_count] = checked_int_span(run, what); + plan.nnz = run; + return plan; +} + +template +static HostCompactSparseWindowPlan plan_csc_rows_window( + const IndexT* h_indices, const IndptrT* h_indptr, int col_start, + int n_window_cols, RowToLocal row_to_local, const char* what) { + return plan_compact_sparse_window( + n_window_cols, + [&](int local_col) { + int col = col_start + local_col; + size_t count = 0; + for (IndptrT p = h_indptr[col]; p < h_indptr[col + 1]; p++) { + if (row_to_local((int)h_indices[p]) >= 0) count++; + } + return count; + }, + what); +} + +template +static HostCompactSparseWindowPlan plan_csr_cols_window( + const IndexT* h_indices, const IndptrT* h_indptr, const int* row_ids, + int n_window_rows, ColToLocal col_to_local, const char* what) { + return plan_compact_sparse_window( + n_window_rows, + [&](int local_row) { + int row = row_ids ? row_ids[local_row] : local_row; + size_t count = 0; + for (IndptrT p = h_indptr[row]; p < h_indptr[row + 1]; p++) { + if (col_to_local((int)h_indices[p]) >= 0) count++; + } + return count; + }, + what); +} + +template +static HostCompactSparseWindowPlan plan_csc_rows_window_from_map( + const IndexT* h_indices, const IndptrT* h_indptr, int col_start, + int n_window_cols, const int* row_map, const char* what) { + return plan_csc_rows_window( + h_indices, h_indptr, col_start, n_window_cols, + [&](int row) { return row_map[row]; }, what); +} + +template +static HostCompactSparseWindowPlan plan_csr_cols_window_from_map( + const IndexT* h_indices, const IndptrT* h_indptr, const int* row_ids, + int n_window_rows, const int* col_map, const char* what) { + return plan_csr_cols_window( + h_indices, h_indptr, row_ids, n_window_rows, + [&](int col) { return col_map[col]; }, what); +} + +// RAII guard for cudaHostRegister: unregisters on scope exit (incl. exception +// unwind), preventing leaked host pinning on stream-sync failures. +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0, + bool best_effort = false) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered = owned elsewhere; use it without + // unregistering. Other failures make mapped reads unsafe, so + // surface them -- unless best_effort (pin is only a speedup; + // unpinned H2D still works). + if (err == cudaErrorHostMemoryAlreadyRegistered || + best_effort) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind). +// Stream dtor SYNCHRONIZES before destroying. CRITICAL ordering: declare the +// RmmScratchPool BEFORE these guards so streams (destroyed first) drain +// in-flight kernels before the pool (destroyed last) frees the scratch they +// read. +struct ScopedCudaStream { + cudaStream_t stream = nullptr; + + ScopedCudaStream() = default; + explicit ScopedCudaStream(unsigned int flags) { + cuda_check(cudaStreamCreateWithFlags(&stream, flags), + "cudaStreamCreateWithFlags"); + } + ~ScopedCudaStream() { + if (stream) { + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + } + } + operator cudaStream_t() const { + return stream; + } + cudaStream_t get() const { + return stream; + } + ScopedCudaStream(const ScopedCudaStream&) = delete; + ScopedCudaStream& operator=(const ScopedCudaStream&) = delete; +}; + +struct ScopedCudaStreams { + std::vector streams; + + // `flags` is explicit so call sites keep their original stream semantics. + ScopedCudaStreams(int n, unsigned int flags) { + streams.reserve(n > 0 ? (size_t)n : 0); + for (int i = 0; i < n; ++i) { + cudaStream_t s = nullptr; + cudaError_t err = cudaStreamCreateWithFlags(&s, flags); + if (err != cudaSuccess) { + // dtor won't run on ctor throw; reclaim what we made. + for (cudaStream_t prev : streams) { + cudaStreamSynchronize(prev); + cudaStreamDestroy(prev); + } + throw std::runtime_error( + std::string("cudaStreamCreateWithFlags failed: ") + + cudaGetErrorString(err)); + } + streams.push_back(s); + } + } + ~ScopedCudaStreams() { + for (cudaStream_t s : streams) { + if (!s) continue; + cudaStreamSynchronize(s); + cudaStreamDestroy(s); + } + } + cudaStream_t operator[](int i) const { + return streams[i]; + } + int size() const { + return (int)streams.size(); + } + ScopedCudaStreams(const ScopedCudaStreams&) = delete; + ScopedCudaStreams& operator=(const ScopedCudaStreams&) = delete; +}; + +// Drain every stream, surfacing the first async error with a context label. +static inline void sync_streams(const ScopedCudaStreams& streams, + const char* what) { + for (int i = 0; i < streams.size(); ++i) { + cudaError_t err = cudaStreamSynchronize(streams[i]); + if (err != cudaSuccess) + throw std::runtime_error(std::string("CUDA error in ") + what + + ": " + cudaGetErrorString(err)); + } +} + +struct ScopedCudaEvent { + cudaEvent_t event = nullptr; + + ScopedCudaEvent() = default; + explicit ScopedCudaEvent(unsigned int flags) { + cuda_check(cudaEventCreateWithFlags(&event, flags), + "cudaEventCreateWithFlags"); + } + ~ScopedCudaEvent() { + if (event) cudaEventDestroy(event); + } + void record(cudaStream_t stream) { + cuda_check(cudaEventRecord(event, stream), "cudaEventRecord"); + } + cudaEvent_t get() const { + return event; + } + ScopedCudaEvent(const ScopedCudaEvent&) = delete; + ScopedCudaEvent& operator=(const ScopedCudaEvent&) = delete; +}; + +template +struct PinnedRingArray { + std::vector> data; + std::vector pins; + + PinnedRingArray() = default; + PinnedRingArray(int n_slots, size_t count) : data(n_slots), pins(n_slots) { + size_t n = count ? count : 1; + for (int s = 0; s < n_slots; s++) { + data[s].reset(new T[n]); + pins[s] = HostRegisterGuard(data[s].get(), n * sizeof(T)); + } + } + T* get(int slot) { + return data[slot].get(); + } + const T* get(int slot) const { + return data[slot].get(); + } +}; + +// Per-slot pinned host staging with events, so CPU materialization into one +// slot can overlap GPU consumption of another. All arrays have the same item +// capacity; use a second ring for differently-sized metadata such as indptr. +template +struct PinnedRing { + std::tuple...> arrays; + std::vector evt; + std::vector used; + int n_slots = 0; + size_t capacity = 0; + + PinnedRing(int n_slots_, size_t count) + : arrays(PinnedRingArray(n_slots_, count)...), + evt(n_slots_, nullptr), + used(n_slots_, 0) { + n_slots = n_slots_; + capacity = count ? count : 1; + for (int s = 0; s < n_slots; s++) { + cuda_check( + cudaEventCreateWithFlags(&evt[s], cudaEventDisableTiming), + "PinnedRing event create"); + } + } + ~PinnedRing() { + for (size_t s = 0; s < evt.size(); ++s) { + cudaEvent_t e = evt[s]; + if (!e) continue; + if (s < used.size() && used[s]) cudaEventSynchronize(e); + cudaEventDestroy(e); + } + } + void wait(int s) { + if (used[s]) + cuda_check(cudaEventSynchronize(evt[s]), "PinnedRing reuse"); + } + void record(int s, cudaStream_t stream) { + cuda_check(cudaEventRecord(evt[s], stream), "PinnedRing record"); + used[s] = true; + } + template + typename std::tuple_element>::type* get(int slot) { + return std::get(arrays).get(slot); + } + template + const typename std::tuple_element>::type* get( + int slot) const { + return std::get(arrays).get(slot); + } + PinnedRing(const PinnedRing&) = delete; + PinnedRing& operator=(const PinnedRing&) = delete; +}; + +template +using SparseWindowStagingRing = + PinnedRing; + +using HostStagingRing = SparseWindowStagingRing; + +/** Fill linear segment offsets [0, stride, ..., n_segments*stride] on-device. + */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Rebase a slice of indptr: out[i] = indptr[col+i] - indptr[col]. Grid-strided + * (arbitrary `count`). Templated so 64-bit global indptrs produce 32-bit + * pack-local indptrs (per-pack nnz fits int32 via the memory budget). */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +// Threaded host gather of selected rows into compact staging (f32 vals + int32 +// cols) at disjoint per-row offsets (compact_indptr - base) -> race-free. +// No-pin alternative to the mapped gather kernel: only the compacted slice +// crosses the bus. +template +static void host_gather_rows_compact_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, const CompactT* compact_indptr, CompactT base, + int n_target, StageValT* stage_vals, StageIndexT* stage_cols) { + host_parallel_ranges(n_target, [&](int i0, int i1) { + for (int i = i0; i < i1; i++) { + int r = row_ids[i]; + IndptrT rs = h_indptr[r]; + int nnz = (int)(h_indptr[r + 1] - rs); + size_t ds = (size_t)(compact_indptr[i] - base); + for (int k = 0; k < nnz; k++) { + stage_vals[ds + k] = (StageValT)h_data[rs + k]; + stage_cols[ds + k] = (StageIndexT)h_indices[rs + k]; + } + } + }); +} + +template +static void host_gather_rows_compact(const InT* h_data, const IndexT* h_indices, + const IndptrT* h_indptr, + const int* row_ids, + const CompactT* compact_indptr, + CompactT base, int n_target, + float* stage_vals, int* stage_cols) { + host_gather_rows_compact_as(h_data, h_indices, h_indptr, + row_ids, compact_indptr, base, + n_target, stage_vals, stage_cols); +} + +// Threaded host cast-copy of a contiguous nnz slice into staging (f32 + int32). +// CSC analogue of host_gather_rows_compact: contiguous column batch, no gather. +// nnz fits int32 (batch-bounded). +template +static void host_copy_slice_as(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, StageValT* stage_vals, + StageIndexT* stage_cols) { + host_parallel_ranges(nnz, [&](int k0, int k1) { + for (int k = k0; k < k1; k++) { + stage_vals[k] = (StageValT)h_data[start + k]; + stage_cols[k] = (StageIndexT)h_indices[start + k]; + } + }); +} + +template +static void host_copy_slice(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, InT* stage_vals, + IndexT* stage_cols) { + host_copy_slice_as(h_data, h_indices, start, nnz, stage_vals, + stage_cols); +} + +template +static void host_cast_copy_slice(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, float* stage_vals, + int* stage_cols) { + host_copy_slice_as(h_data, h_indices, start, nnz, stage_vals, + stage_cols); +} + +// Threaded host gather of selected dense rows and a contiguous column window +// into an F-order staging tile. f_order describes the full source matrix; the +// staged output is always [n_window_rows, n_window_cols] in column-major order. +template +static void host_materialize_dense_rows_window_as( + const InT* h_X, bool f_order, int n_full_rows, int n_full_cols, + const int* row_ids, int n_window_rows, int col_start, int n_window_cols, + StageT* stage) { + int total = + checked_int_product((size_t)n_window_rows, (size_t)n_window_cols, + "dense host row-window items"); + host_parallel_ranges(total, [&](int i0, int i1) { + for (int idx = i0; idx < i1; idx++) { + int local_col = idx / n_window_rows; + int local_row = idx - local_col * n_window_rows; + int row = row_ids ? row_ids[local_row] : local_row; + int col = col_start + local_col; + size_t src = f_order ? (size_t)col * n_full_rows + row + : (size_t)row * n_full_cols + col; + stage[(size_t)local_col * n_window_rows + local_row] = + (StageT)h_X[src]; + } + }); +} + +template +static void host_materialize_dense_rows_window(const InT* h_X, bool f_order, + int n_full_rows, int n_full_cols, + const int* row_ids, + int n_window_rows, int col_start, + int n_window_cols, InT* stage) { + host_materialize_dense_rows_window_as( + h_X, f_order, n_full_rows, n_full_cols, row_ids, n_window_rows, + col_start, n_window_cols, stage); +} + +// Cross-axis CSC materialization: filter a contiguous column window by selected +// rows and emit compact CSC with local row ids. +template +static void host_materialize_csc_rows_window_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int col_start, int n_window_cols, const int* compact_indptr, + RowToLocal row_to_local, StageValT* stage_vals, StageIndexT* stage_rows) { + host_parallel_ranges(n_window_cols, [&](int c0, int c1) { + for (int local_col = c0; local_col < c1; local_col++) { + int col = col_start + local_col; + size_t dst = (size_t)compact_indptr[local_col]; + for (IndptrT p = h_indptr[col]; p < h_indptr[col + 1]; p++) { + int local_row = row_to_local((int)h_indices[p]); + if (local_row < 0) continue; + stage_vals[dst] = (StageValT)h_data[p]; + stage_rows[dst] = (StageIndexT)local_row; + dst++; + } + } + }); +} + +template +static void host_materialize_csc_rows_window( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int col_start, int n_window_cols, const int* compact_indptr, + const int* row_map, float* stage_vals, int* stage_rows) { + host_materialize_csc_rows_window_as( + h_data, h_indices, h_indptr, col_start, n_window_cols, compact_indptr, + [&](int row) { return row_map[row]; }, stage_vals, stage_rows); +} + +// Cross-axis CSR materialization: filter selected rows by selected columns and +// emit compact CSR with local column ids. +template +static void host_materialize_csr_cols_window_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, int n_window_rows, const int* compact_indptr, + ColToLocal col_to_local, StageValT* stage_vals, StageIndexT* stage_cols) { + host_parallel_ranges(n_window_rows, [&](int r0, int r1) { + for (int local_row = r0; local_row < r1; local_row++) { + int row = row_ids ? row_ids[local_row] : local_row; + size_t dst = (size_t)compact_indptr[local_row]; + for (IndptrT p = h_indptr[row]; p < h_indptr[row + 1]; p++) { + int local_col = col_to_local((int)h_indices[p]); + if (local_col < 0) continue; + stage_vals[dst] = (StageValT)h_data[p]; + stage_cols[dst] = (StageIndexT)local_col; + dst++; + } + } + }); +} + +template +static void host_materialize_csr_cols_window( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, int n_window_rows, const int* compact_indptr, + const int* col_map, float* stage_vals, int* stage_cols) { + host_materialize_csr_cols_window_as( + h_data, h_indices, h_indptr, row_ids, n_window_rows, compact_indptr, + [&](int col) { return col_map[col]; }, stage_vals, stage_cols); +} + +// Optimized CSR -> contiguous-column-window materialization for sorted rows and +// ascending column batches. The per-row cursor means each nonzero is examined +// once across the full stream. +template +static int host_materialize_csr_column_interval_cursor_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_rows, int col_start, int col_end, IndptrT* cursor, int* row_counts, + int* compact_indptr, StageValT* stage_vals, StageIndexT* stage_cols, + const char* what) { + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + const IndexT* row_base = h_indices + h_indptr[r]; + const IndexT* lo = row_base + cursor[r]; + const IndexT* hi = h_indices + h_indptr[r + 1]; + if (lo < hi && *lo < (IndexT)col_start) { + lo = std::lower_bound(lo, hi, (IndexT)col_start); + cursor[r] = (IndptrT)(lo - row_base); + } + row_counts[r] = + (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); + } + }); + + compact_indptr[0] = 0; + for (int r = 0; r < n_rows; r++) { + compact_indptr[r + 1] = checked_int_span( + (size_t)compact_indptr[r] + (size_t)row_counts[r], what); + } + int batch_nnz = compact_indptr[n_rows]; + + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + IndptrT base = h_indptr[r] + cursor[r]; + size_t dst = (size_t)compact_indptr[r]; + int count = row_counts[r]; + for (int k = 0; k < count; k++) { + stage_vals[dst + k] = (StageValT)h_data[base + k]; + stage_cols[dst + k] = (StageIndexT)h_indices[base + k]; + } + cursor[r] += count; + } + }); + return batch_nnz; +} + +template +static int host_materialize_csr_column_interval_cursor( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_rows, int col_start, int col_end, IndptrT* cursor, int* row_counts, + int* compact_indptr, InT* stage_vals, int* stage_cols, const char* what) { + return host_materialize_csr_column_interval_cursor_as( + h_data, h_indices, h_indptr, n_rows, col_start, col_end, cursor, + row_counts, compact_indptr, stage_vals, stage_cols, what); +} + +/** Fill linear segment offsets [0, stride, ...] on the supplied stream (avoids + * serializing multi-stream pipelines). */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index e7419c7b..74d75574 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -23,12 +23,11 @@ static void launch_ovr_rank_dense_streaming( if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) { - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - } - - size_t sub_items = (size_t)n_rows * sub_batch_cols; + DenseColumnBatchPlan batches = plan_dense_column_batches( + n_rows, n_cols, sub_batch_cols, SAFE_BATCH_NNZ, "Dense OVR sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t sub_items = batches.max_items; int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); size_t cub_temp_bytes = @@ -141,12 +140,12 @@ static void launch_ovr_rank_dense_host_streaming( // F-order float32 input feeds the sort directly (no cast/transpose buffer). const bool fast_keys = f_order && std::is_same::value; - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) { - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - } - - size_t sub_items = (size_t)n_rows * sub_batch_cols; + DenseColumnBatchPlan batches = + plan_dense_column_batches(n_rows, n_cols, sub_batch_cols, + SAFE_BATCH_NNZ, "Dense host OVR sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t sub_items = batches.max_items; int sub_items_i32 = checked_cub_items(sub_items, "Dense host OVR sub-batch"); size_t cub_temp_bytes = @@ -347,9 +346,11 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( n_sort_groups = (int)h_sort_group_ids.size(); } - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "Dense OVO sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; int sub_ref_items_i32 = @@ -370,6 +371,20 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( size_t ref_cub_temp_bytes = cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + { + size_t per_stream = + sub_ref_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + ref_cub_temp_bytes + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? grp_cub_temp_bytes : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + n_streams = clamp_streams_by_budget(n_streams, per_stream, + rmm_available_device_bytes(0.8)); + } + // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); @@ -485,6 +500,327 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref( sync_streams(streams, "dense OVO tiered rank"); } +template +static void launch_ovo_rank_dense_host_streaming( + const T* h_X, bool f_order, const int* h_ref_row_ids, + const int* h_grp_row_ids, const int* h_grp_offsets, double* rank_sums, + double* tie_corr, double* group_sums, double* group_sum_sq, + double* group_nnz, int n_full_rows, int n_ref, int n_all_grp, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, bool compute_nnz, + bool compute_stats, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + if (compute_stats && n_groups_stats != n_groups + 1) { + throw std::runtime_error( + "dense OVO host stats require n_groups_stats == n_groups + 1"); + } + if (h_grp_offsets[0] != 0 || h_grp_offsets[n_groups] != n_all_grp) { + throw std::runtime_error( + "dense OVO host group offsets must span n_all_grp"); + } + + auto tier_plan = make_ovo_tier_plan(h_grp_offsets, n_groups); + int max_grp_size = tier_plan.max_grp_size; + bool run_large = tier_plan.above_medium && tier_plan.run_large; + bool run_huge = tier_plan.above_medium && !run_large; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "Dense host OVO sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense host OVO reference sub-batch"); + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense host OVO group sub-batch"); + constexpr bool fast_keys = std::is_same::value; + int n_stats_rows = n_groups + 1; + + size_t grp_cub_temp_bytes = 0; + if (run_huge) { + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense host OVO group segment count"); + grp_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + } + size_t ref_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + + { + size_t native_items = sub_ref_items + sub_grp_items; + size_t per_stream = + native_items * sizeof(T) + + (fast_keys ? 0 : native_items * sizeof(float)) + + sub_ref_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + ref_cub_temp_bytes + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? grp_cub_temp_bytes : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + (compute_stats + ? 2 * (size_t)n_stats_rows * sub_batch_cols * sizeof(double) + : 0) + + (compute_nnz + ? (size_t)n_stats_rows * sub_batch_cols * sizeof(double) + : 0); + n_streams = clamp_streams_by_budget(n_streams, per_stream, + rmm_available_device_bytes(0.8)); + } + + RmmScratchPool pool; + PinnedRing stage(n_streams, batches.max_items); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + int* d_grp_offsets = pool.alloc(n_groups + 1); + cuda_check(cudaMemcpy(d_grp_offsets, h_grp_offsets, + (size_t)(n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice), + "dense host OVO offsets H2D"); + + int* d_sort_group_ids = nullptr; + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "dense host OVO sort group ids H2D"); + } + + int* d_grp_codes = nullptr; + if (compute_stats) { + std::vector h_grp_codes(n_all_grp, -1); + for (int g = 0; g < n_groups; g++) { + int begin = h_grp_offsets[g]; + int end = h_grp_offsets[g + 1]; + if (begin < 0 || end < begin || end > n_all_grp) { + throw std::runtime_error( + "dense OVO host group offsets are invalid"); + } + std::fill(h_grp_codes.begin() + begin, h_grp_codes.begin() + end, + g); + } + d_grp_codes = pool.alloc(n_all_grp); + cuda_check( + cudaMemcpy(d_grp_codes, h_grp_codes.data(), + (size_t)n_all_grp * sizeof(int), cudaMemcpyHostToDevice), + "dense host OVO group codes H2D"); + } + + struct StreamBuf { + T* ref_native; + T* grp_native; + float* ref_f32; + float* grp_f32; + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sum_sq; + double* sub_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_native = pool.alloc(sub_ref_items); + bufs[s].grp_native = pool.alloc(sub_grp_items); + bufs[s].ref_f32 = + fast_keys ? nullptr : pool.alloc(sub_ref_items); + bufs[s].grp_f32 = + fast_keys ? nullptr : pool.alloc(sub_grp_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + bufs[s].grp_cub_temp = + run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense host OVO group segments"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + bufs[s].sub_group_sums = + compute_stats + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + bufs[s].sub_group_sum_sq = + compute_stats + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + int tpb = UTIL_BLOCK_SIZE; + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense host OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense host OVO active group sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + stage.wait(s); + T* h_ref_stage = stage.template get<0>(s); + T* h_grp_stage = stage.template get<1>(s); + + host_materialize_dense_rows_window(h_X, f_order, n_full_rows, n_cols, + h_ref_row_ids, n_ref, col, sb_cols, + h_ref_stage); + host_materialize_dense_rows_window(h_X, f_order, n_full_rows, n_cols, + h_grp_row_ids, n_all_grp, col, + sb_cols, h_grp_stage); + + cuda_check(cudaMemcpyAsync(buf.ref_native, h_ref_stage, + (size_t)sb_ref_items_actual * sizeof(T), + cudaMemcpyHostToDevice, stream), + "dense host OVO ref H2D"); + cuda_check(cudaMemcpyAsync(buf.grp_native, h_grp_stage, + (size_t)sb_grp_items_actual * sizeof(T), + cudaMemcpyHostToDevice, stream), + "dense host OVO group H2D"); + stage.record(s, stream); + + const float* ref_sub; + const float* grp_sub; + if (fast_keys) { + ref_sub = reinterpret_cast(buf.ref_native); + grp_sub = reinterpret_cast(buf.grp_native); + } else { + unsigned int ref_grid = (unsigned int)std::min( + ((size_t)sb_ref_items_actual + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel + <<>>( + buf.ref_native, buf.ref_f32, n_ref, sb_cols, true); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + unsigned int grp_grid = (unsigned int)std::min( + ((size_t)sb_grp_items_actual + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel + <<>>( + buf.grp_native, buf.grp_f32, n_all_grp, sb_cols, true); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + ref_sub = buf.ref_f32; + grp_sub = buf.grp_f32; + } + + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + cub_segmented_sortkeys(buf.ref_cub_temp, ref_cub_temp_bytes, ref_sub, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "dense host OVO ref segmented sort"); + ref_sub = buf.ref_sorted; + + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.grp_cub_temp}; + ovo_dispatch_tiers(ref_sub, grp_sub, d_grp_offsets, tier_plan, sc, + d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVO rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check( + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVO tie_corr D2D copy"); + } + + if (compute_stats) { + cuda_check( + cudaMemsetAsync(buf.sub_group_sums, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group sums memset"); + cuda_check( + cudaMemsetAsync(buf.sub_group_sum_sq, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group sumsq memset"); + if (compute_nnz) { + cuda_check(cudaMemsetAsync( + buf.sub_group_nnz, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group nnz memset"); + } + dense_ovo_group_stats_kernel<<>>( + buf.ref_native, buf.grp_native, d_grp_codes, buf.sub_group_sums, + buf.sub_group_sum_sq, + compute_nnz ? buf.sub_group_nnz : buf.sub_group_sums, n_ref, + n_all_grp, sb_cols, n_groups, compute_nnz); + CUDA_CHECK_LAST_ERROR(dense_ovo_group_stats_kernel); + scatter_cols_2d(group_sums + col, buf.sub_group_sums, n_stats_rows, + n_cols, sb_cols, stream); + scatter_cols_2d(group_sum_sq + col, buf.sub_group_sum_sq, + n_stats_rows, n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(group_nnz + col, buf.sub_group_nnz, + n_stats_rows, n_cols, sb_cols, stream); + } + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense host OVO streaming"); +} + template static void def_ovr_rank_dense_host_streaming(nb::module_& m) { m.def( @@ -518,12 +854,79 @@ static void def_ovr_rank_dense_host_streaming(nb::module_& m) { "sub_batch_cols"_a = SUB_BATCH_COLS); } +template +static void def_ovo_rank_dense_host_streaming(nb::module_& m) { + m.def( + "ovo_rank_dense_host_streaming", + [](host_array buf, host_array ref_row_ids, + host_array grp_row_ids, host_array grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, + gpu_array_c group_sums, + gpu_array_c group_sum_sq, + gpu_array_c group_nnz, int n_full_rows, int n_cols, + int n_groups, bool f_order, bool compute_tie_corr, bool compute_nnz, + bool compute_stats, int sub_batch_cols) { + nb_require(buf.shape(0) == (size_t)n_full_rows * (size_t)n_cols, + "ovo_rank_host: buf length must be n_rows*n_cols"); + int n_ref = (int)ref_row_ids.shape(0); + int n_all_grp = (int)grp_row_ids.shape(0); + nb_require((int)grp_offsets.shape(0) == n_groups + 1, + "ovo_rank_host: grp_offsets length must be n_groups+1"); + nb_require(rank_sums.ndim() == 2 && tie_corr.ndim() == 2, + "ovo_rank_host: rank_sums/tie_corr must be 2D"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovo_rank_host: rank_sums shape must be " + "(n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_groups && + (int)tie_corr.shape(1) == n_cols, + "ovo_rank_host: tie_corr shape must be " + "(n_groups, n_cols)"); + int n_groups_stats = compute_stats ? (int)group_sums.shape(0) : 0; + if (compute_stats) { + nb_require(group_sums.ndim() == 2 && group_sum_sq.ndim() == 2, + "ovo_rank_host: stats outputs must be 2D"); + nb_require(n_groups_stats == n_groups + 1 && + (int)group_sums.shape(1) == n_cols, + "ovo_rank_host: group_sums shape must be " + "(n_groups+1, n_cols)"); + nb_require((int)group_sum_sq.shape(0) == n_groups + 1 && + (int)group_sum_sq.shape(1) == n_cols, + "ovo_rank_host: group_sum_sq shape must be " + "(n_groups+1, n_cols)"); + if (compute_nnz) { + nb_require(group_nnz.ndim() == 2 && + (int)group_nnz.shape(0) == n_groups + 1 && + (int)group_nnz.shape(1) == n_cols, + "ovo_rank_host: group_nnz shape must be " + "(n_groups+1, n_cols)"); + } + } + launch_ovo_rank_dense_host_streaming( + buf.data(), f_order, ref_row_ids.data(), grp_row_ids.data(), + grp_offsets.data(), rank_sums.data(), tie_corr.data(), + compute_stats ? group_sums.data() : nullptr, + compute_stats ? group_sum_sq.data() : nullptr, + compute_nnz ? group_nnz.data() : nullptr, n_full_rows, n_ref, + n_all_grp, n_cols, n_groups, n_groups_stats, compute_tie_corr, + compute_nnz, compute_stats, sub_batch_cols); + }, + "buf"_a, "ref_row_ids"_a, "grp_row_ids"_a, "grp_offsets"_a, + "rank_sums"_a, "tie_corr"_a, "group_sums"_a, "group_sum_sq"_a, + "group_nnz"_a, nb::kw_only(), "n_full_rows"_a, "n_cols"_a, "n_groups"_a, + "f_order"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; def_ovr_rank_dense_host_streaming(m); def_ovr_rank_dense_host_streaming(m); + def_ovo_rank_dense_host_streaming(m); + def_ovo_rank_dense_host_streaming(m); m.def( "ovo_rank_dense_tiered_unsorted_ref", diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 1d07b780..02059ecc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -3,10 +3,6 @@ #include #include #include -#include -#include -#include -#include #include #include @@ -16,44 +12,7 @@ #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR #include "../rmm_scratch.h" // rmm_allocate, RmmScratchPool, ScopedCudaBuffer #include "../sparse_extract/sparse_extract.cuh" // csr_extract_dense* kernels - -// Host thread count for CPU-side CSR passes: hardware concurrency, capped. -static inline int host_worker_count() { - unsigned hw = std::thread::hardware_concurrency(); - return (int)std::min(hw ? hw : 4u, 32u); -} - -// Run fn(chunk, r0, r1) over a partition of [0, n); `chunk` = 0-based worker -// index. fn runs concurrently: read-only shared state, disjoint output ranges -// (keyed by chunk or [r0,r1)). Returns chunks used; serial for small n. -template -static inline int host_parallel_chunks(int n, F fn) { - if (n <= 0) return 0; - int n_threads = host_worker_count(); - if (n_threads <= 1 || n < 4096) { - fn(0, 0, n); - return 1; - } - int chunk = (n + n_threads - 1) / n_threads; - std::vector pool; - pool.reserve(n_threads); - for (int t = 0; t < n_threads; t++) { - int r0 = t * chunk; - if (r0 >= n) break; - int r1 = std::min(n, r0 + chunk); - pool.emplace_back([&fn, t, r0, r1]() { fn(t, r0, r1); }); - } - int used = (int)pool.size(); - for (std::thread& th : pool) th.join(); - return used; -} - -// Run fn(r0, r1) over a partition of [0, n) across hardware threads (serial for -// small n). Concurrent: read-only shared state, disjoint output ranges. -template -static inline void host_parallel_ranges(int n, F fn) { - host_parallel_chunks(n, [&fn](int, int r0, int r1) { fn(r0, r1); }); -} +#include "../streaming/streaming.cuh" constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; @@ -61,38 +20,8 @@ constexpr int N_STREAMS = 4; constexpr int SUB_BATCH_COLS = 64; constexpr int BEGIN_BIT = 0; constexpr int END_BIT = 32; -// Default thread-per-block for utility kernels. -constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; - -// Stream-count clamps: never use more streams than column batches, nor more -// than the per-stream memory budget allows. -static inline int clamp_streams_by_cols(int n_cols, int sub_batch_cols) { - int n = N_STREAMS; - if (n_cols < n * sub_batch_cols) - n = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - return n; -} - -static inline int clamp_streams_by_budget(int n_streams, - size_t per_stream_bytes, - size_t budget) { - while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) - n_streams--; - return n_streams; -} - -// Scatter a [rows, sb_cols] device sub-batch (row-major doubles, src stride -// sb_cols) into `dst` (stride n_cols). `dst` must point at the dest column -// offset (e.g. out + col). -static inline void scatter_cols_2d(double* dst, const double* src, int rows, - int n_cols, int sb_cols, - cudaStream_t stream) { - cudaMemcpy2DAsync(dst, n_cols * sizeof(double), src, - sb_cols * sizeof(double), sb_cols * sizeof(double), rows, - cudaMemcpyDeviceToDevice, stream); -} // MEDIUM band cap: groups up to this size use unsorted O(n^2) in-group-count // rank (no smem sort). Tier dispatch: make_ovo_tier_plan. constexpr int OVO_MEDIUM_MAX = 512; @@ -202,288 +131,15 @@ static inline size_t wilcoxon_max_smem_per_block() { return cached_smem; } -static inline int checked_cub_items(size_t count, const char* context) { - if (count > (size_t)std::numeric_limits::max()) { - throw std::runtime_error(std::string(context) + - " exceeds CUB int item limit"); - } - return (int)count; -} - -static inline int checked_int_span(size_t count, const char* context) { - if (count > (size_t)std::numeric_limits::max()) { - throw std::runtime_error(std::string(context) + - " exceeds int32 offset limit"); - } - return (int)count; -} - -static inline int checked_int_product(size_t a, size_t b, const char* context) { - if (a != 0 && b > (size_t)std::numeric_limits::max() / a) { - throw std::runtime_error(std::string(context) + - " exceeds int32 item limit"); - } - return (int)(a * b); -} - -// Precompute per-batch CSC column offsets rebased to each batch's ptr_start, -// laid out [n_batches][sub_batch_cols+1], upload once (from `pool`). Avoids a -// per-batch H2D from a transient host buffer. -template -static inline int* precompute_csc_batch_offsets(const IndptrT* h_indptr, - int n_cols, int sub_batch_cols, - int n_batches, - RmmScratchPool& pool, - const char* what) { - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int rem = n_cols - col_start; - int sb = (sub_batch_cols < rem) ? sub_batch_cols : rem; - IndptrT ptr_start = h_indptr[col_start]; - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = checked_int_span( - (size_t)(h_indptr[col_start + i] - ptr_start), what); - } - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - return d_all_offsets; -} - // Max per-batch nnz: a batch is sorted in one CUB segmented call (int32 item // count) and addressed with int offsets, so it must stay below INT_MAX. -constexpr size_t SAFE_BATCH_NNZ = 2000000000; // < INT_MAX - -// Halve sub_batch_cols until the densest window holds <= cap nonzeros, keeping -// every batch's nnz within int32 for CUB and bounding per-stream transpose/sort -// scratch. col_nnz(i) = nnz of column i. Worst case returns 1 (single column, -// nnz <= n_rows). -template -static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, - size_t cap, ColNnz col_nnz) { - if (cap < 1) cap = 1; - auto max_window = [&](int s) { - size_t mx = 0; - for (int c = 0; c < n_cols; c += s) { - int e = std::min(c + s, n_cols); - size_t sum = 0; - for (int i = c; i < e; i++) sum += col_nnz(i); - if (sum > mx) mx = sum; - } - return mx; - }; - while (sub_batch_cols > 1 && max_window(sub_batch_cols) > cap) - sub_batch_cols = (sub_batch_cols + 1) / 2; - return sub_batch_cols; -} - -// RAII guard for cudaHostRegister: unregisters on scope exit (incl. exception -// unwind), preventing leaked host pinning on stream-sync failures. -struct HostRegisterGuard { - void* ptr = nullptr; - - HostRegisterGuard() = default; - HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0, - bool best_effort = false) { - if (p && bytes > 0) { - cudaError_t err = cudaHostRegister(p, bytes, flags); - if (err != cudaSuccess) { - // Already-registered = owned elsewhere; use it without - // unregistering. Other failures make mapped reads unsafe, so - // surface them -- unless best_effort (pin is only a speedup; - // unpinned H2D still works). - if (err == cudaErrorHostMemoryAlreadyRegistered || - best_effort) { - cudaGetLastError(); // clear sticky error flag - } else { - throw std::runtime_error( - std::string("cudaHostRegister failed (") + - std::to_string((size_t)bytes) + - " bytes, flags=" + std::to_string(flags) + - "): " + cudaGetErrorString(err)); - } - } else { - ptr = p; - } - } - } - ~HostRegisterGuard() { - if (ptr) cudaHostUnregister(ptr); - } - HostRegisterGuard(const HostRegisterGuard&) = delete; - HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; - HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { - other.ptr = nullptr; - } - HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { - if (this != &other) { - if (ptr) cudaHostUnregister(ptr); - ptr = other.ptr; - other.ptr = nullptr; - } - return *this; - } -}; - -// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind). -// Stream dtor SYNCHRONIZES before destroying. CRITICAL ordering: declare the -// RmmScratchPool BEFORE these guards so streams (destroyed first) drain -// in-flight kernels before the pool (destroyed last) frees the scratch they -// read. -struct ScopedCudaStream { - cudaStream_t stream = nullptr; - - ScopedCudaStream() = default; - explicit ScopedCudaStream(unsigned int flags) { - cuda_check(cudaStreamCreateWithFlags(&stream, flags), - "cudaStreamCreateWithFlags"); - } - ~ScopedCudaStream() { - if (stream) { - cudaStreamSynchronize(stream); // drain before teardown - cudaStreamDestroy(stream); - } - } - operator cudaStream_t() const { - return stream; - } - cudaStream_t get() const { - return stream; - } - ScopedCudaStream(const ScopedCudaStream&) = delete; - ScopedCudaStream& operator=(const ScopedCudaStream&) = delete; -}; - -struct ScopedCudaStreams { - std::vector streams; - - // `flags` is explicit so call sites keep their original stream semantics. - ScopedCudaStreams(int n, unsigned int flags) { - streams.reserve(n > 0 ? (size_t)n : 0); - for (int i = 0; i < n; ++i) { - cudaStream_t s = nullptr; - cudaError_t err = cudaStreamCreateWithFlags(&s, flags); - if (err != cudaSuccess) { - // dtor won't run on ctor throw; reclaim what we made. - for (cudaStream_t prev : streams) { - cudaStreamSynchronize(prev); - cudaStreamDestroy(prev); - } - throw std::runtime_error( - std::string("cudaStreamCreateWithFlags failed: ") + - cudaGetErrorString(err)); - } - streams.push_back(s); - } - } - ~ScopedCudaStreams() { - for (cudaStream_t s : streams) { - if (!s) continue; - cudaStreamSynchronize(s); // drain before teardown - cudaStreamDestroy(s); - } - } - cudaStream_t operator[](int i) const { - return streams[i]; - } - int size() const { - return (int)streams.size(); - } - ScopedCudaStreams(const ScopedCudaStreams&) = delete; - ScopedCudaStreams& operator=(const ScopedCudaStreams&) = delete; -}; - -// Drain every stream, surfacing the first async error with a context label. -static inline void sync_streams(const ScopedCudaStreams& streams, - const char* what) { - for (int i = 0; i < streams.size(); ++i) { - cudaError_t err = cudaStreamSynchronize(streams[i]); - if (err != cudaSuccess) - throw std::runtime_error(std::string("CUDA error in ") + what + - ": " + cudaGetErrorString(err)); - } -} - -struct ScopedCudaEvent { - cudaEvent_t event = nullptr; - - ScopedCudaEvent() = default; - explicit ScopedCudaEvent(unsigned int flags) { - cuda_check(cudaEventCreateWithFlags(&event, flags), - "cudaEventCreateWithFlags"); - } - ~ScopedCudaEvent() { - if (event) cudaEventDestroy(event); - } - void record(cudaStream_t stream) { - cuda_check(cudaEventRecord(event, stream), "cudaEventRecord"); - } - cudaEvent_t get() const { - return event; - } - ScopedCudaEvent(const ScopedCudaEvent&) = delete; - ScopedCudaEvent& operator=(const ScopedCudaEvent&) = delete; -}; +constexpr size_t SAFE_BATCH_NNZ = STREAMING_SAFE_BATCH_NNZ; static inline int round_up_to_warp(int n) { int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -// Per-stream pinned host staging (f32 vals + int32 cols) with a per-slot event, -// so a CPU gather into slot s overlaps GPU compute: wait(s) blocks only until -// slot s's prior H2D drained, not the whole pipeline. -struct HostStagingRing { - std::vector> vals; - std::vector> cols; - std::vector pin_v, pin_c; - std::vector evt; - std::vector used; - HostStagingRing(int n_streams, size_t nnz) - : vals(n_streams), - cols(n_streams), - pin_v(n_streams), - pin_c(n_streams), - evt(n_streams, nullptr), - used(n_streams, 0) { - size_t n = nnz ? nnz : 1; - for (int s = 0; s < n_streams; s++) { - vals[s].reset(new float[n]); - cols[s].reset(new int[n]); - pin_v[s] = HostRegisterGuard(vals[s].get(), n * sizeof(float)); - pin_c[s] = HostRegisterGuard(cols[s].get(), n * sizeof(int)); - cuda_check( - cudaEventCreateWithFlags(&evt[s], cudaEventDisableTiming), - "HostStagingRing event create"); - } - } - ~HostStagingRing() { - for (cudaEvent_t e : evt) - if (e) cudaEventDestroy(e); - } - void wait(int s) { - if (used[s]) - cuda_check(cudaEventSynchronize(evt[s]), "HostStagingRing reuse"); - } - void record(int s, cudaStream_t stream) { - cuda_check(cudaEventRecord(evt[s], stream), "HostStagingRing record"); - used[s] = true; - } - HostStagingRing(const HostStagingRing&) = delete; - HostStagingRing& operator=(const HostStagingRing&) = delete; -}; - -/** Fill linear segment offsets [0, stride, ..., n_segments*stride] on-device. - */ -__global__ void fill_linear_offsets_kernel(int* __restrict__ out, - int n_segments, int stride) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i <= n_segments) out[i] = i * stride; -} - /** Per-row stats codes for a pack of K groups. From pack_grp_offsets (size K+1, * relative to pack start), write stats_codes[r] = base_slot + group_idx(r) via * binary search over the K+1 offsets. */ @@ -504,57 +160,6 @@ __global__ void fill_pack_stats_codes_kernel( stats_codes[r] = base_slot + lo; } -/** Rebase a slice of indptr: out[i] = indptr[col+i] - indptr[col]. Grid-strided - * (arbitrary `count`). Templated so 64-bit global indptrs produce 32-bit - * pack-local indptrs (per-pack nnz fits int32 via the memory budget). */ -template -__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, - IdxOut* __restrict__ out, int col, - int count) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); -} - -// Threaded host gather of selected rows into compact staging (f32 vals + int32 -// cols) at disjoint per-row offsets (compact_indptr - base) -> race-free. -// No-pin alternative to the mapped gather kernel: only the compacted slice -// crosses the bus. -template -static void host_gather_rows_compact(const InT* h_data, const IndexT* h_indices, - const IndptrT* h_indptr, - const int* row_ids, - const CompactT* compact_indptr, - CompactT base, int n_target, - float* stage_vals, int* stage_cols) { - host_parallel_ranges(n_target, [&](int i0, int i1) { - for (int i = i0; i < i1; i++) { - int r = row_ids[i]; - IndptrT rs = h_indptr[r]; - int nnz = (int)(h_indptr[r + 1] - rs); - size_t ds = (size_t)(compact_indptr[i] - base); - for (int k = 0; k < nnz; k++) { - stage_vals[ds + k] = (float)h_data[rs + k]; - stage_cols[ds + k] = (int)h_indices[rs + k]; - } - } - }); -} - -// Threaded host cast-copy of a contiguous nnz slice into staging (f32 + int32). -// CSC analogue of host_gather_rows_compact: contiguous column batch, no gather. -// nnz fits int32 (batch-bounded). -template -static void host_cast_copy_slice(const InT* h_data, const IndexT* h_indices, - size_t start, int nnz, float* stage_vals, - int* stage_cols) { - host_parallel_ranges(nnz, [&](int k0, int k1) { - for (int k = k0; k < k1; k++) { - stage_vals[k] = (float)h_data[start + k]; - stage_cols[k] = (int)h_indices[start + k]; - } - }); -} - // Per-group stats over an already-compact CSR (accumulate half of the mapped // gather kernel, decoupled for host-staged data). slot = stats_codes[r] or // fixed_slot; slot outside [0,n_groups_stats) is skipped. @@ -578,14 +183,3 @@ __global__ void csr_compact_accumulate_kernel( atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); } } - -/** Fill linear segment offsets [0, stride, ...] on the supplied stream (avoids - * serializing multi-stream pipelines). */ -static inline void upload_linear_offsets(int* d_offsets, int n_segments, - int stride, cudaStream_t stream) { - int count = n_segments + 1; - int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - fill_linear_offsets_kernel<<>>( - d_offsets, n_segments, stride); - CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); -} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 37768d2b..4eae5a45 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -13,14 +13,10 @@ static void ovo_streaming_csr_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - // Cap sub_batch_cols so the group slab (n_all_grp × sub_batch_cols, one CUB - // sort) stays within int32; n_all_grp (cell count) drives the cap. - { - size_t cap = n_all_grp > 0 ? SAFE_BATCH_NNZ / (size_t)n_all_grp - : (size_t)sub_batch_cols; - if (cap < 1) cap = 1; - if ((size_t)sub_batch_cols > cap) sub_batch_cols = (int)cap; - } + DenseColumnBatchPlan group_batches = plan_dense_column_batches( + n_all_grp, n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO device CSR group sub-batch"); + sub_batch_cols = group_batches.sub_batch_cols; std::vector h_offsets(n_groups + 1); cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, @@ -248,15 +244,10 @@ static void ovo_streaming_csc_impl( int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - // Cap sub_batch_cols so both slabs (n_ref× and n_all_grp× sub_batch_cols, - // each one CUB sort) stay within int32; these cell counts drive the cap. - { - size_t max_rows = (size_t)std::max(n_ref, n_all_grp); - size_t cap = - max_rows > 0 ? SAFE_BATCH_NNZ / max_rows : (size_t)sub_batch_cols; - if (cap < 1) cap = 1; - if ((size_t)sub_batch_cols > cap) sub_batch_cols = (int)cap; - } + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO device CSC sub-batch"); + sub_batch_cols = batches.sub_batch_cols; std::vector h_offsets(n_groups + 1); cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 00364e17..6ad19806 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -1,5 +1,101 @@ #pragma once +struct OvoHostCsrPack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; +}; + +struct OvoHostCsrPackPlan { + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = 0; + size_t max_sub_items = 0; +}; + +template +static OvoHostCsrPackPlan plan_ovo_host_csr_packs( + const int* h_grp_offsets, const IndptrT* h_grp_indptr_compact, + int n_all_grp, int n_test, int n_cols, int n_ref, int sub_batch_cols) { + OvoHostCsrPackPlan plan; + plan.max_pack_sb_cols = sub_batch_cols; + + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe + size_t pack_nnz_cap = SAFE_PACK_NNZ; + { + int target_streams = std::min(N_STREAMS, n_test); + if (target_streams < 1) target_streams = 1; + size_t dev_budget = rmm_available_device_bytes(0.9); + size_t ref_bytes = (size_t)n_ref * (size_t)n_cols * sizeof(float); + size_t reserve = (size_t)target_streams * OVO_PACK_FIXED_PER_STREAM; + size_t grp_avail = dev_budget > ref_bytes ? dev_budget - ref_bytes : 0; + size_t data_avail = grp_avail > reserve ? grp_avail - reserve : 0; + size_t cap = data_avail / ((size_t)target_streams * 2 * sizeof(float)); + if (cap < OVO_MIN_PACK_NNZ) cap = OVO_MIN_PACK_NNZ; + if (cap < pack_nnz_cap) pack_nnz_cap = cap; + } + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows && + cur_nnz + nnz_g <= pack_nnz_cap); + if (!can_add) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + plan.packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + plan.packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + + for (const OvoHostCsrPack& pk : plan.packs) { + int K = pk.end - pk.first; + if (pk.n_rows > plan.max_pack_rows) plan.max_pack_rows = pk.n_rows; + if (pk.nnz > plan.max_pack_nnz) plan.max_pack_nnz = pk.nnz; + if (K > plan.max_pack_K) plan.max_pack_K = K; + int pack_items = + checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, + "OVO host CSR pack dense slab"); + if (pack_items > plan.max_pack_items) plan.max_pack_items = pack_items; + checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); + if (pk.sb_cols > plan.max_pack_sb_cols) + plan.max_pack_sb_cols = pk.sb_cols; + } + plan.max_sub_items = (size_t)plan.max_pack_items; + return plan; +} + /** * Host-streaming CSC OVO pipeline: CSC on host, only each column sub-batch is * sent to GPU; row maps + group offsets uploaded once; results written back @@ -17,16 +113,15 @@ static void ovo_streaming_csc_host_impl( // Cap sub_batch_cols so neither the dense ref/group slabs (rows × // sub_batch_cols, one CUB call) nor per-batch nnz exceed int32. - { - size_t max_rows = (size_t)std::max(n_ref, n_all_grp); - size_t dense_cap = - max_rows > 0 ? SAFE_BATCH_NNZ / max_rows : (size_t)sub_batch_cols; - if (dense_cap < 1) dense_cap = 1; - if ((size_t)sub_batch_cols > dense_cap) sub_batch_cols = (int)dense_cap; - sub_batch_cols = cap_sub_batch_by_nnz( - n_cols, sub_batch_cols, SAFE_BATCH_NNZ, - [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); - } + DenseColumnBatchPlan dense_batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO host CSC dense sub-batch"); + sub_batch_cols = dense_batches.sub_batch_cols; + size_t sparse_cap = SAFE_BATCH_NNZ; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, sub_batch_cols, sparse_cap, + "OVO host CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; auto t1 = make_ovo_tier_plan(h_grp_offsets, n_groups); int max_grp_size = t1.max_grp_size; @@ -62,24 +157,22 @@ static void ovo_streaming_csc_host_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - // Max nnz across any sub-batch for sparse transfer buffer sizing - size_t max_nnz = 0; - for (int c = 0; c < n_cols; c += sub_batch_cols) { - int sb = std::min(sub_batch_cols, n_cols - c); - size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); - if (nnz > max_nnz) max_nnz = nnz; - } + size_t max_nnz = batches.max_nnz; + constexpr size_t window_value_bytes = + sizeof(WilcoxonSparseWindowDTypes::value_type); // Clamp streams so per-stream scratch fits the budget: dense slabs scale // with cell counts, so a fixed N_STREAMS would OOM at scale. { size_t per_stream = - max_nnz * (sizeof(InT) + sizeof(float) + sizeof(IndexT)) + - 2 * sub_ref_items * sizeof(float) + - (run_huge ? 2 : 1) * sub_grp_items * sizeof(float) + - 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + - (compute_nnz ? 2 : 1) * (size_t)n_groups_stats * sub_batch_cols * - sizeof(double) + + sparse_window_nnz_bytes(max_nnz) + + 2 * sub_ref_items * window_value_bytes + + (run_huge ? 2 : 1) * sub_grp_items * window_value_bytes + + sparse_window_accum_bytes( + 2 * (size_t)n_groups * sub_batch_cols) + + (compute_nnz ? 2 : 1) * + sparse_window_accum_bytes( + (size_t)n_groups_stats * sub_batch_cols) + cub_temp_bytes; size_t budget = rmm_available_device_bytes(0.8); n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); @@ -87,14 +180,12 @@ static void ovo_streaming_csc_host_impl( // pool first: streams drain before it frees their scratch (RAII order). RmmScratchPool pool; - // No full-matrix page-lock: each column batch is cast-copied into small - // per-stream pinned staging (f32 vals + int32 cols) and bulk-H2D'd. + // Bounded staging avoids page-locking huge host CSC arrays and gives every + // dtype/index combination the same device footprint. + HostStagingRing stage(n_streams, max_nnz); ScopedCudaStreams streams(n_streams, cudaStreamDefault); - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - int* d_all_offsets = precompute_csc_batch_offsets( - h_indptr, n_cols, sub_batch_cols, n_batches, pool, - "OVO host CSC rebased column offsets"); + int* d_all_offsets = upload_batch_offsets(batches, pool); // Row maps + group offsets + stats codes (uploaded once) int* d_ref_row_map = pool.alloc(n_rows); @@ -170,9 +261,6 @@ static void ovo_streaming_csc_host_impl( } } - // Per-stream pinned staging for the contiguous column-batch cast-copy. - HostStagingRing stage(n_streams, max_nnz); - int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); bool cast_use_gmem = false; @@ -202,10 +290,10 @@ static void ovo_streaming_csc_host_impl( // the next copy overlap compute. stage.wait(s); host_cast_copy_slice(h_data, h_indices, (size_t)ptr_start, nnz_i, - stage.vals[s].get(), stage.cols[s].get()); - cudaMemcpyAsync(buf.d_sparse_data_f32, stage.vals[s].get(), + stage.get<0>(s), stage.get<1>(s)); + cudaMemcpyAsync(buf.d_sparse_data_f32, stage.get<0>(s), nnz * sizeof(float), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(buf.d_sparse_indices, stage.cols[s].get(), + cudaMemcpyAsync(buf.d_sparse_indices, stage.get<1>(s), nnz * sizeof(int), cudaMemcpyHostToDevice, stream); stage.record(s, stream); int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); @@ -319,101 +407,15 @@ static void ovo_streaming_csr_host_impl( h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; } - // Build packs (same rule as grp_impl, but uses compacted indptr) - struct Pack { - int first; - int end; - int n_rows; - size_t nnz; - int sb_cols; - }; - std::vector packs; - int max_pack_rows = 0; - size_t max_pack_nnz = 0; - int max_pack_K = 0; - int max_pack_items = 0; - int max_pack_sb_cols = sub_batch_cols; - { - int target_packs = N_STREAMS; - int target_rows = (n_all_grp + target_packs - 1) / target_packs; - if (target_rows < 1) target_rows = 1; - size_t budget_cap_rows = - GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; - if ((size_t)target_rows > budget_cap_rows) - target_rows = (int)budget_cap_rows; - // Bound each pack's compacted nnz < INT_MAX (feeds int32 CUB item - // counts + offsets); splits dense groups across more packs. - constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe - - // Budget-aware pack-nnz cap: shrink packs so ~N_STREAMS of them plus - // the sorted ref cache and per-stream fixed scratch fit in 90% of free - // device memory. On a tight GPU this yields more, smaller packs (the - // per-stream device buffers then fit; the Phase-2 stream clamp sets the - // final stream count); on a big GPU it stays at the int32-safe cap, so - // large-GPU pack layout (and timing) is unchanged. Only the per-stream - // pack buffers scale with this; the ref cache is bounded separately. - size_t pack_nnz_cap = SAFE_PACK_NNZ; - { - int target_streams = std::min(N_STREAMS, n_test); - if (target_streams < 1) target_streams = 1; - size_t dev_budget = rmm_available_device_bytes(0.9); - size_t ref_bytes = (size_t)n_ref * (size_t)n_cols * sizeof(float); - size_t reserve = (size_t)target_streams * OVO_PACK_FIXED_PER_STREAM; - size_t grp_avail = - dev_budget > ref_bytes ? dev_budget - ref_bytes : 0; - size_t data_avail = grp_avail > reserve ? grp_avail - reserve : 0; - size_t cap = - data_avail / ((size_t)target_streams * 2 * sizeof(float)); - if (cap < OVO_MIN_PACK_NNZ) cap = OVO_MIN_PACK_NNZ; - if (cap < pack_nnz_cap) pack_nnz_cap = cap; - } - - int cur_first = 0; - int cur_rows = 0; - size_t cur_nnz = 0; - for (int g = 0; g < n_test; g++) { - int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; - size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - - h_grp_indptr_compact[h_grp_offsets[g]]); - int new_rows = cur_rows + n_g; - bool can_add = (cur_rows == 0) || (new_rows <= target_rows && - cur_nnz + nnz_g <= pack_nnz_cap); - if (!can_add) { - size_t sb_size = - std::min((size_t)n_cols, - GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); - if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; - packs.push_back( - {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); - cur_first = g; - cur_rows = n_g; - cur_nnz = nnz_g; - } else { - cur_rows = new_rows; - cur_nnz += nnz_g; - } - } - if (cur_rows > 0) { - size_t sb_size = std::min( - (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); - if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; - packs.push_back( - {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); - } - } - for (const Pack& pk : packs) { - int K = pk.end - pk.first; - if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; - if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; - if (K > max_pack_K) max_pack_K = K; - int pack_items = - checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, - "OVO host CSR pack dense slab"); - if (pack_items > max_pack_items) max_pack_items = pack_items; - checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); - if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; - } - size_t max_sub_items = (size_t)max_pack_items; + OvoHostCsrPackPlan pack_plan = plan_ovo_host_csr_packs( + h_grp_offsets, h_grp_indptr_compact.data(), n_all_grp, n_test, n_cols, + n_ref, sub_batch_cols); + const std::vector& packs = pack_plan.packs; + int max_pack_rows = pack_plan.max_pack_rows; + size_t max_pack_nnz = pack_plan.max_pack_nnz; + int max_pack_K = pack_plan.max_pack_K; + int max_pack_sb_cols = pack_plan.max_pack_sb_cols; + size_t max_sub_items = pack_plan.max_sub_items; if (max_pack_rows == 0) return; RmmScratchPool pool; @@ -435,12 +437,7 @@ static void ovo_streaming_csr_host_impl( // Pinned staging for the reference gather (compacted f32 vals + int32 // cols). Uninitialized: the gather overwrites it, so skip a multi-GB zero. size_t ref_stage_n = ref_nnz ? (size_t)ref_nnz : 1; - std::unique_ptr h_ref_stage_vals(new float[ref_stage_n]); - std::unique_ptr h_ref_stage_cols(new int[ref_stage_n]); - HostRegisterGuard pin_ref_vals(h_ref_stage_vals.get(), - ref_stage_n * sizeof(float)); - HostRegisterGuard pin_ref_cols(h_ref_stage_cols.get(), - ref_stage_n * sizeof(int)); + PinnedRing ref_stage(1, ref_stage_n); // Upload row_ids + compacted indptrs + group boundaries int* d_ref_row_ids = pool.alloc(n_ref); @@ -502,16 +499,16 @@ static void ovo_streaming_csr_host_impl( if (n_ref > 0 && ref_nnz > 0) { host_gather_rows_compact(h_data, h_indices, h_indptr, h_ref_row_ids, h_ref_indptr_compact.data(), 0, n_ref, - h_ref_stage_vals.get(), - h_ref_stage_cols.get()); - cuda_check(cudaMemcpyAsync(d_ref_data_f32, h_ref_stage_vals.get(), + ref_stage.get<0>(0), ref_stage.get<1>(0)); + cuda_check(cudaMemcpyAsync(d_ref_data_f32, ref_stage.get<0>(0), (size_t)ref_nnz * sizeof(float), cudaMemcpyHostToDevice, ref_stream), "OVO host CSR ref staged vals H2D"); - cuda_check(cudaMemcpyAsync(d_ref_indices, h_ref_stage_cols.get(), + cuda_check(cudaMemcpyAsync(d_ref_indices, ref_stage.get<1>(0), (size_t)ref_nnz * sizeof(int), cudaMemcpyHostToDevice, ref_stream), "OVO host CSR ref staged cols H2D"); + ref_stage.record(0, ref_stream); if (compute_sums || compute_nnz) { csr_compact_accumulate_kernel<<>>( @@ -575,6 +572,8 @@ static void ovo_streaming_csr_host_impl( int max_pack_kernel_seg = checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, "OVO host CSR pack segment buffer"); + constexpr size_t window_value_bytes = + sizeof(WilcoxonSparseWindowDTypes::value_type); // Clamp streams to the device-memory budget (90% of free). The per-stream // pack buffers + dense slabs dominate device use, so a fixed stream count @@ -584,17 +583,18 @@ static void ovo_streaming_csr_host_impl( // means less gather/compute overlap, not a re-stream. { size_t per_stream = - 2 * max_pack_nnz * sizeof(float) // grp data + idx - + (size_t)(max_pack_rows + 1) * sizeof(int) // grp indptr - + (size_t)max_pack_rows * sizeof(int) // stats codes - + (size_t)(max_pack_K + 1) * sizeof(int) // pack grp offsets - + max_sub_items * sizeof(float) // grp dense - + 2 * (size_t)max_pack_K * max_pack_sb_cols * - sizeof(double) // rank+tie - + (size_t)max_pack_sb_cols * sizeof(double) // ref tie + sparse_window_nnz_bytes(max_pack_nnz) + + (size_t)(max_pack_rows + 1) * sizeof(int) // grp indptr + + (size_t)max_pack_rows * sizeof(int) // stats codes + + (size_t)(max_pack_K + 1) * sizeof(int) // pack grp offsets + + max_sub_items * window_value_bytes // grp dense + + sparse_window_accum_bytes( + 2 * (size_t)max_pack_K * max_pack_sb_cols) // rank+tie + + sparse_window_accum_bytes( + (size_t)max_pack_sb_cols) // ref tie + (may_need_cub - ? max_sub_items * sizeof(float) // grp sorted + ? max_sub_items * window_value_bytes // grp sorted + (size_t)max_pack_K * sizeof(int) // sort ids + 2 * (size_t)max_pack_kernel_seg * sizeof(int) // segs + cub_grp_bytes // cub temp @@ -660,7 +660,7 @@ static void ovo_streaming_csr_host_impl( int stage_slot = 0; for (int p = 0; p < (int)packs.size(); p++) { - const Pack& pack = packs[p]; + const OvoHostCsrPack& pack = packs[p]; int K = pack.end - pack.first; if (K == 0 || pack.n_rows == 0) continue; OvoTierPlan pack_t1 = make_ovo_tier_plan(h_grp_offsets + pack.first, K); @@ -742,17 +742,17 @@ static void ovo_streaming_csr_host_impl( h_data, h_indices, h_indptr, h_grp_row_ids + row_start + rb0, h_grp_indptr_compact.data() + row_start + rb0, blk_base, - rb1 - rb0, stage.vals[slot].get(), stage.cols[slot].get()); - cuda_check(cudaMemcpyAsync(buf.d_grp_data_f32 + dev_off, - stage.vals[slot].get(), - blk_nnz * sizeof(float), - cudaMemcpyHostToDevice, stream), - "OVO host CSR pack staged vals H2D"); - cuda_check(cudaMemcpyAsync(buf.d_grp_indices + dev_off, - stage.cols[slot].get(), - blk_nnz * sizeof(int), - cudaMemcpyHostToDevice, stream), - "OVO host CSR pack staged cols H2D"); + rb1 - rb0, stage.get<0>(slot), stage.get<1>(slot)); + cuda_check( + cudaMemcpyAsync(buf.d_grp_data_f32 + dev_off, + stage.get<0>(slot), blk_nnz * sizeof(float), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged vals H2D"); + cuda_check( + cudaMemcpyAsync(buf.d_grp_indices + dev_off, + stage.get<1>(slot), blk_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged cols H2D"); stage.record(slot, stream); rb0 = rb1; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index f9495b65..a7358246 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -24,6 +24,40 @@ __global__ void build_huge_seg_offsets_kernel( ends[idx] = base + grp_offsets[g + 1]; } +template +__global__ void dense_ovo_group_stats_kernel( + const T* __restrict__ ref_dense, const T* __restrict__ grp_dense, + const int* __restrict__ grp_codes, double* __restrict__ group_sums, + double* __restrict__ group_sum_sq, double* __restrict__ group_nnz, + int n_ref, int n_all_grp, int sb_cols, int n_groups, bool compute_nnz) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int ref_slot = n_groups; + const T* ref_col = ref_dense + (size_t)col * n_ref; + const T* grp_col = grp_dense + (size_t)col * n_all_grp; + + for (int row = threadIdx.x; row < n_ref; row += blockDim.x) { + double v = (double)ref_col[row]; + atomicAdd(&group_sums[(size_t)ref_slot * sb_cols + col], v); + atomicAdd(&group_sum_sq[(size_t)ref_slot * sb_cols + col], v * v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)ref_slot * sb_cols + col], 1.0); + } + } + + for (int row = threadIdx.x; row < n_all_grp; row += blockDim.x) { + int g = grp_codes[row]; + if (g < 0 || g >= n_groups) continue; + double v = (double)grp_col[row]; + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + atomicAdd(&group_sum_sq[(size_t)g * sb_cols + col], v * v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } +} + /** * Sizing knobs for LARGE-band dispatch: largest group fits in smem -> fused * bitonic-sort + binary-search kernel per block; else fall back to HUGE band diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 16788103..d435df6c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -16,26 +16,21 @@ static void ovr_sparse_csc_host_streaming_impl( // Bound each batch's nnz: CUB item counts stay within int32 + per-stream // sort buffers fit the budget (column counts free from CSC indptr). + size_t cap = SAFE_BATCH_NNZ; { constexpr size_t BYTES_PER_NNZ = sizeof(InT) + 2 * sizeof(float) + 2 * sizeof(IndexT) + 8; - size_t cap = SAFE_BATCH_NNZ; size_t mem_cap = rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; - sub_batch_cols = cap_sub_batch_by_nnz( - n_cols, sub_batch_cols, cap, - [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); } + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, sub_batch_cols, cap, + "OVR host CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); - - size_t max_nnz = 0; - for (int col = 0; col < n_cols; col += sub_batch_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); - if (nnz > max_nnz) max_nnz = nnz; - } + size_t max_nnz = batches.max_nnz; size_t cub_temp_bytes = 0; if (max_nnz > 0) { @@ -47,13 +42,21 @@ static void ovr_sparse_csc_host_streaming_impl( // pool first: streams drain before it frees their scratch (see guard doc). RmmScratchPool pool; - // Pin host inputs before streams: on exception unwind streams drain before - // buffers are unregistered (mirrors safe CSR order). size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(IndexT)); + size_t direct_pin_bytes = total_nnz * (sizeof(InT) + sizeof(IndexT)); + bool use_bounded_stage = + direct_pin_bytes > HOST_STREAMING_DIRECT_PIN_LIMIT_BYTES; + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + std::unique_ptr> stage; + if (use_bounded_stage) { + stage.reset(new PinnedRing(n_streams, max_nnz)); + } else { + pin_data = HostRegisterGuard(const_cast(h_data), + total_nnz * sizeof(InT)); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + } ScopedCudaStreams streams(n_streams, cudaStreamDefault); int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); @@ -99,10 +102,7 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyHostToDevice); // Pre-compute rebased per-batch offsets, upload once (no per-batch H2D). - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - int* d_all_offsets = precompute_csc_batch_offsets( - h_indptr, n_cols, sub_batch_cols, n_batches, pool, - "OVR host CSC rebased column offsets"); + int* d_all_offsets = upload_batch_offsets(batches, pool); int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; @@ -137,8 +137,24 @@ static void ovr_sparse_csc_host_streaming_impl( int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), "OVR host CSC active batch nnz"); - // H2D: this column range's sparse data (native dtype). - if (batch_nnz > 0) { + if (use_bounded_stage) { + // Bounded staging: copy native values/indices into a small pinned + // slot instead of page-locking the whole host CSC. + stage->wait(s); + if (batch_nnz > 0) { + host_copy_slice(h_data, h_indices, (size_t)ptr_start, batch_nnz, + stage->template get<0>(s), + stage->template get<1>(s)); + cudaMemcpyAsync(buf.d_sparse_data_orig, + stage->template get<0>(s), + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, stage->template get<1>(s), + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + stage->record(s, stream); + } else if (batch_nnz > 0) { cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, (size_t)batch_nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); @@ -263,24 +279,12 @@ static void ovr_sparse_csr_host_rowstream_impl( size_t cap = SAFE_BATCH_NNZ; size_t mem_cap = budget / BYTES_PER_NNZ; if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; - sub_batch_cols = cap_sub_batch_by_nnz( - n_cols, sub_batch_cols, cap, [&](int c) { return h_col_counts[c]; }); - - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t max_batch_nnz = 0; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - std::vector h_batch_nnz(n_batches); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb = std::min(sub_batch_cols, n_cols - col_start); - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i < sb; i++) - off[i + 1] = - checked_int_span((size_t)off[i] + h_col_counts[col_start + i], - "rowstream rebased column offsets"); - h_batch_nnz[b] = (size_t)off[sb]; - if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; - } + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, [&](int c) { return h_col_counts[c]; }, + "rowstream rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { @@ -299,11 +303,8 @@ static void ovr_sparse_csr_host_rowstream_impl( // NOT page-locked: gather reads it on CPU, only compacted slice crosses // bus. size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; - std::vector h_gather_vals(stage_nnz); - std::vector h_gather_cols(stage_nnz); - std::vector h_gather_indptr(n_rows + 1); - HostRegisterGuard pin_gvals(h_gather_vals.data(), stage_nnz * sizeof(InT)); - HostRegisterGuard pin_gcols(h_gather_cols.data(), stage_nnz * sizeof(int)); + PinnedRing gather_stage(1, stage_nnz); + PinnedRing indptr_stage(1, (size_t)n_rows + 1); std::vector cursor(n_rows, 0); // offset within each (sorted) row int* d_group_codes = pool.alloc(n_rows); @@ -318,6 +319,7 @@ static void ovr_sparse_csr_host_rowstream_impl( int* d_gather_indptr = pool.alloc(n_rows + 1); int* col_offsets = pool.alloc(sub_batch_cols + 1); int* write_pos = pool.alloc(sub_batch_cols); + int* d_all_offsets = upload_batch_offsets(batches, pool); InT* csc_vals_orig = pool.alloc(max_batch_nnz); float* csc_vals_f32 = pool.alloc(max_batch_nnz); int* csc_row_idx = pool.alloc(max_batch_nnz); @@ -335,6 +337,8 @@ static void ovr_sparse_csr_host_rowstream_impl( double* d_nz_scratch = rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; + ScopedCudaStream row_stream(cudaStreamDefault); + cudaStream_t stream = row_stream.get(); // ---- One linear column-batched pass. Cursor advances monotonically // (sorted indices + ascending batches): each nonzero read/transferred once, @@ -346,62 +350,43 @@ static void ovr_sparse_csr_host_rowstream_impl( for (int b = 0; b < n_batches; b++) { int sb_cols = std::min(sub_batch_cols, n_cols - col); int col_end = col + sb_cols; - - // Per-row run for this batch: binary-search sorted indices from cursor - // for first column >= col_end. - host_parallel_ranges(n_rows, [&](int r0, int r1) { - for (int r = r0; r < r1; r++) { - const IndexT* lo = h_indices + h_indptr[r] + cursor[r]; - const IndexT* hi = h_indices + h_indptr[r + 1]; - g_count[r] = - (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); - } - }); - // Prefix sum -> per-row output offsets (gather mini-CSR row pointer). - h_gather_indptr[0] = 0; - for (int r = 0; r < n_rows; r++) - h_gather_indptr[r + 1] = checked_int_span( - (size_t)h_gather_indptr[r] + (size_t)g_count[r], - "rowstream gather nnz"); - int batch_nnz = h_gather_indptr[n_rows]; - // Copy each row's run into its slot, advance cursor (disjoint outputs - // -> race-free). - host_parallel_ranges(n_rows, [&](int r0, int r1) { - for (int r = r0; r < r1; r++) { - IndptrT base = h_indptr[r] + cursor[r]; - size_t gpos = (size_t)h_gather_indptr[r]; - int cnt = g_count[r]; - for (int k = 0; k < cnt; k++) { - h_gather_vals[gpos + k] = h_data[base + k]; - h_gather_cols[gpos + k] = (int)h_indices[base + k]; - } - cursor[r] += cnt; - } - }); - - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - cudaMemcpy(col_offsets, off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(write_pos, off, sb_cols * sizeof(int), - cudaMemcpyHostToDevice); + gather_stage.wait(0); + indptr_stage.wait(0); + InT* h_gather_vals = gather_stage.template get<0>(0); + int* h_gather_cols = gather_stage.template get<1>(0); + int* h_gather_indptr = indptr_stage.template get<0>(0); + + int batch_nnz = host_materialize_csr_column_interval_cursor( + h_data, h_indices, h_indptr, n_rows, col, col_end, cursor.data(), + g_count.data(), h_gather_indptr, h_gather_vals, h_gather_cols, + "rowstream gather nnz"); + + int* off = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(col_offsets, off, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(write_pos, off, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); // Bulk H2D of this batch's compacted nonzeros (1x transfer). if (batch_nnz > 0) { - cuda_check(cudaMemcpy(d_gather_vals, h_gather_vals.data(), - (size_t)batch_nnz * sizeof(InT), - cudaMemcpyHostToDevice), + cuda_check(cudaMemcpyAsync(d_gather_vals, h_gather_vals, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream), "rowstream gathered vals H2D"); - cuda_check(cudaMemcpy(d_gather_cols, h_gather_cols.data(), - (size_t)batch_nnz * sizeof(int), - cudaMemcpyHostToDevice), + cuda_check(cudaMemcpyAsync(d_gather_cols, h_gather_cols, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream), "rowstream gathered cols H2D"); } - cudaMemcpy(d_gather_indptr, h_gather_indptr.data(), - (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpyAsync(d_gather_indptr, h_gather_indptr, + (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + gather_stage.record(0, stream); + indptr_stage.record(0, stream); // Scatter mini-CSR into the column-batch CSC accumulator. csr_scatter_to_csc_kernel - <<<(n_rows + tpb - 1) / tpb, tpb>>>( + <<<(n_rows + tpb - 1) / tpb, tpb, 0, stream>>>( d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, csc_vals_orig, csc_row_idx, n_rows, col, col_end, 0); CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); @@ -409,36 +394,38 @@ static void ovr_sparse_csr_host_rowstream_impl( launch_ovr_cast_and_accumulate_sparse( csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, d_group_codes, sub_group_sums, sub_group_nnz, sb_cols, n_groups, - compute_nnz, tpb, smem_cast, cast_use_gmem, 0); + compute_nnz, tpb, smem_cast, cast_use_gmem, stream); if (batch_nnz > 0) { cub_segmented_sortpairs(cub_temp, cub_temp_bytes, csc_vals_f32, keys_out, csc_row_idx, vals_out, batch_nnz, - sb_cols, col_offsets, col_offsets + 1, 0, - "rowstream segmented sort"); + sb_cols, col_offsets, col_offsets + 1, + stream, "rowstream segmented sort"); } launch_ovr_sparse_rank( keys_out, vals_out, col_offsets, d_group_codes, d_group_sizes, sub_rank_sums, sub_tie_corr, d_nz_scratch, n_rows, sb_cols, - n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, 0); + n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); - cudaMemcpy2D(d_rank_sums + col, n_cols * sizeof(double), sub_rank_sums, - sb_cols * sizeof(double), sb_cols * sizeof(double), - n_groups, cudaMemcpyDeviceToDevice); + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) - cudaMemcpy(d_tie_corr + col, sub_tie_corr, sb_cols * sizeof(double), - cudaMemcpyDeviceToDevice); - cudaMemcpy2D(d_group_sums + col, n_cols * sizeof(double), - sub_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice); + cudaMemcpyAsync(d_tie_corr + col, sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); if (compute_nnz) - cudaMemcpy2D(d_group_nnz + col, n_cols * sizeof(double), - sub_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice); + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); col += sb_cols; } - cuda_check(cudaDeviceSynchronize(), "rowstream sync"); + cuda_check(cudaStreamSynchronize(stream), "rowstream sync"); } /** @@ -532,30 +519,14 @@ static void ovr_sparse_csr_host_streaming_impl( size_t batch_nnz_cap = SAFE_BATCH_NNZ; size_t mem_cap = budget / (size_t)N_STREAMS / BYTES_PER_NNZ; if (mem_cap > 0 && mem_cap < batch_nnz_cap) batch_nnz_cap = mem_cap; - sub_batch_cols = - cap_sub_batch_by_nnz(n_cols, sub_batch_cols, batch_nnz_cap, - [&](int c) { return (size_t)h_col_counts[c]; }); - - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t max_batch_nnz = 0; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - std::vector h_batch_nnz(n_batches); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb_cols = std::min(sub_batch_cols, n_cols - col_start); - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i < sb_cols; i++) - off[i + 1] = checked_int_span( - (size_t)off[i] + (size_t)h_col_counts[col_start + i], - "OVR host CSR rebased column offsets"); - h_batch_nnz[b] = (size_t)off[sb_cols]; - if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; - } - - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, batch_nnz_cap, + [&](int c) { return (size_t)h_col_counts[c]; }, + "OVR host CSR rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; + int* d_all_offsets = upload_batch_offsets(batches, pool); // ---- Phase 1: per-stream bounded work buffer size + stream count ---- size_t cub_temp_bytes = 0; @@ -673,7 +644,7 @@ static void ovr_sparse_csr_host_streaming_impl( auto stream = streams[s]; auto& buf = bufs[s]; int batch_nnz = - checked_int_span(h_batch_nnz[b], "OVR host CSR active batch nnz"); + checked_int_span(batches.nnz[b], "OVR host CSR active batch nnz"); int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), @@ -748,26 +719,17 @@ static void ovr_sparse_csc_streaming_impl( cudaMemcpyDeviceToHost); // Bound each batch's nnz: CUB item counts within int32 + sort buffers fit. - { - constexpr size_t BYTES_PER_NNZ = - 2 * sizeof(float) + 2 * sizeof(int) + 8; - size_t cap = SAFE_BATCH_NNZ; - size_t mem_cap = - rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; - if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; - sub_batch_cols = cap_sub_batch_by_nnz( - n_cols, sub_batch_cols, cap, - [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }); - } - + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr.data(), n_cols, sub_batch_cols, cap, + "OVR device CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); - - size_t max_nnz = 0; - for (int col = 0; col < n_cols; col += sub_batch_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); - if (nnz > max_nnz) max_nnz = nnz; - } + size_t max_nnz = batches.max_nnz; size_t cub_temp_bytes = 0; if (max_nnz > 0) { @@ -916,42 +878,21 @@ static void ovr_sparse_csr_streaming_impl( // Bound each batch's nnz: CUB item counts within int32 + transpose/sort // buffers fit. - { - constexpr size_t BYTES_PER_NNZ = - 2 * sizeof(float) + 2 * sizeof(int) + 8; - size_t cap = SAFE_BATCH_NNZ; - size_t mem_cap = - rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; - if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; - sub_batch_cols = cap_sub_batch_by_nnz( - n_cols, sub_batch_cols, cap, - [&](int c) { return (size_t)h_col_counts[c]; }); - } - - // Per-batch prefix sums on host; flat n_batches x (sub_batch_cols+1). - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t max_batch_nnz = 0; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - std::vector h_batch_nnz(n_batches); - - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb_cols = std::min(sub_batch_cols, n_cols - col_start); - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - off[0] = 0; - for (int i = 0; i < sb_cols; i++) - off[i + 1] = checked_int_span( - (size_t)off[i] + (size_t)h_col_counts[col_start + i], - "OVR device CSR rebased column offsets"); - h_batch_nnz[b] = (size_t)off[sb_cols]; - if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; - } + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)h_col_counts[c]; }, + "OVR device CSR rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; // Upload all batch offsets in one H2D. - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + int* d_all_offsets = upload_batch_offsets(batches, pool); // ---- Phase 1: per-stream buffers ---- size_t cub_temp_bytes = 0; @@ -1028,7 +969,7 @@ static void ovr_sparse_csr_streaming_impl( auto stream = streams[s]; auto& buf = bufs[s]; int batch_nnz = - checked_int_span(h_batch_nnz[b], "OVR device CSR active batch nnz"); + checked_int_span(batches.nnz[b], "OVR device CSR active batch nnz"); int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 5ada25d3..e867fea8 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -35,54 +35,6 @@ OVO_DEVICE_SPARSE_SUB_BATCH = 128 OVR_DENSE_SUB_BATCH = 64 OVO_DENSE_TIERED_SUB_BATCH = 256 -DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 # leave headroom for rank buffers - - -def _maybe_preload_host_dense(rg: _RankGenes) -> None: - """Preload a moderate host-dense matrix to the GPU (OVO only). - - OVO caches the sorted reference group on the device and ranks every other - group against it, so staging the matrix on the GPU once is intended -- it - avoids re-streaming the shared reference per group. The OVR path instead - streams column chunks from host (see ``ovr_rank_dense_host_streaming``). - """ - X = rg.X - if not isinstance(X, np.ndarray) or X.size == 0: - return - - try: - free, _ = cp.cuda.runtime.memGetInfo() - except cp.cuda.runtime.CUDARuntimeError: - return - - # Gate on *free* (the rmm_available_device_bytes convention), not total: - # under an RMM pool an array below a fraction of total can still exceed free. - if X.nbytes > free * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: - return - - registered = False - if X.flags.c_contiguous or X.flags.f_contiguous: - try: - cp.cuda.runtime.hostRegister(X.ctypes.data, X.nbytes, 0) - registered = True - except cp.cuda.runtime.CUDARuntimeError: - registered = False - - try: - X_gpu = cp.asarray(X) - cp.cuda.get_current_stream().synchronize() - # Under RMM an OOM surfaces as a bare MemoryError (std::bad_alloc), which - # also subsumes cupy's OutOfMemoryError subclass. - except (MemoryError, cp.cuda.runtime.CUDARuntimeError): - cp.get_default_memory_pool().free_all_blocks() - return - finally: - if registered: - try: - cp.cuda.runtime.hostUnregister(X.ctypes.data) - except cp.cuda.runtime.CUDARuntimeError: - pass - rg.X = X_gpu def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: @@ -200,6 +152,49 @@ def _fill_ovo_stats_from_accumulators( rg._compute_stats_in_chunks = False +def _fill_ovo_dense_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sum_sq_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + means_slots = group_sums_slots / slot_sizes_dev + vars_slots = group_sum_sq_slots / slot_sizes_dev - means_slots**2 + vars_slots = cp.where( + slot_sizes_dev > 1.0, + vars_slots * slot_sizes_dev / (slot_sizes_dev - 1.0), + 0.0, + ) + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + rg.vars[slot_group_indices] = cp.asnumpy(vars_slots) + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) + + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + def _ovo_logfoldchanges_from_sums( rg: _RankGenes, group_sums_slots: cp.ndarray, @@ -542,10 +537,8 @@ def wilcoxon( return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - # OVO caches the reference on the GPU and ranks each group against it, so - # preloading host-dense input is intended. OVR streams chunks from host. - if rg.ireference is not None: - _maybe_preload_host_dense(rg) + # Host dense OVR and OVO stream column windows from host. Already-device + # dense OVO still uses the device-resident tiered planner. # Aggregate if on GPU, else defer to chunks. rg._basic_stats() X = rg.X @@ -1159,6 +1152,88 @@ def _wilcoxon_with_reference( logfoldchanges_gpu=None, ) + if isinstance(X, np.ndarray): + if X.dtype.kind != "f" or X.dtype.itemsize < 4: + X = X.astype(np.float32) + if X.flags.f_contiguous: + buf, f_order = X.ravel(order="K"), True + elif X.flags.c_contiguous: + buf, f_order = X.ravel(order="K"), False + else: + buf, f_order = np.ascontiguousarray(X).ravel(order="K"), False + dense_sub_batch_cols = ( + _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + if chunk_size is not None + else OVO_DENSE_TIERED_SUB_BATCH + ) + + rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + compute_stats = rg._compute_stats_in_chunks + compute_nnz = compute_stats and rg.comp_pts + n_groups_stats = n_test + 1 + stats_shape = (n_groups_stats, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_sum_sq = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + stats_shape if compute_nnz else (1, 1), + dtype=cp.float64, + ) + + _wc.ovo_rank_dense_host_streaming( + buf, + ref_row_ids, + all_grp_row_ids, + offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sum_sq, + group_nnz, + n_full_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=n_test, + f_order=f_order, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_stats=compute_stats, + sub_batch_cols=dense_sub_batch_cols, + ) + + logfoldchanges_gpu = None + if compute_stats: + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( + rg, + group_sums, + test_sizes, + n_ref, + ) + rg._compute_stats_in_chunks = False + else: + _fill_ovo_dense_stats_from_accumulators( + rg, + group_sums, + group_sum_sq, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + ) + + return _finish_ovo( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, + ) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) scores_host = np.empty((n_test, n_total_genes), dtype=np.float64) From 3167d31c0906c9677b9476045fc5a707c293133c Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 26 Jun 2026 12:29:10 +0200 Subject: [PATCH 34/36] update python Signed-off-by: Intron7 --- .github/workflows/docker.yml | 2 +- CMakeLists.txt | 59 +- docker/Dockerfile.deps | 4 +- docker/docker-push.sh | 2 +- src/rapids_singlecell/_cuda/nb_types.h | 4 + .../_cuda/wilcoxon/wilcoxon.cu | 119 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 17 +- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 16 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 99 +- .../_cuda/wilcoxon/wilcoxon_sparse.cu | 102 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 69 +- .../tools/_rank_genes_groups/_utils.py | 2 + .../tools/_rank_genes_groups/_wilcoxon.py | 1515 +++++++++-------- tests/test_rank_genes_groups_wilcoxon.py | 101 +- 14 files changed, 1229 insertions(+), 882 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index e76070b6..7cb46339 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -74,7 +74,7 @@ jobs: - "26.04" CUDA_SUFFIX: - { ver: "12.9.1", label: "cuda12", pkg: "cu12" } - - { ver: "13.0.2", label: "cuda13", pkg: "cu13" } + - { ver: "13.1.0", label: "cuda13", pkg: "cu13" } name: Build Docker images (${{ matrix.CUDA_SUFFIX.label }}) runs-on: ubuntu-latest permissions: diff --git a/CMakeLists.txt b/CMakeLists.txt index c98171b5..935945c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,9 @@ if (RSC_BUILD_EXTENSIONS) ERROR_QUIET ) if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") - list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") + set(_rsc_python_rmm_hint "${RSC_PYTHON_RMM_DIR}") + else() + set(_rsc_python_rmm_hint "") endif() # Wheel builds install librmm/rapids_logger into the isolated build env and # write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks @@ -91,6 +93,10 @@ if (RSC_BUILD_EXTENSIONS) get_filename_component(_rsc_path_prefix "${_rsc_path_entry}/.." ABSOLUTE) _rsc_collect_rapids_python_prefix("${_rsc_path_prefix}") endforeach() + if (NOT RSC_RMM_HINTS + AND NOT "${_rsc_python_rmm_hint}" STREQUAL "") + list(APPEND RSC_RMM_HINTS "${_rsc_python_rmm_hint}") + endif() if (RSC_RAPIDS_CMAKE_PREFIXES) list(APPEND CMAKE_PREFIX_PATH ${RSC_RAPIDS_CMAKE_PREFIXES}) if (RSC_CCCL_HINTS) @@ -107,7 +113,9 @@ if (RSC_BUILD_EXTENSIONS) endif() endif() if (RSC_RMM_HINTS) - find_package(rmm CONFIG REQUIRED HINTS ${RSC_RMM_HINTS}) + list(GET RSC_RMM_HINTS 0 _rsc_rmm_dir) + set(rmm_DIR "${_rsc_rmm_dir}" CACHE PATH "Path to rmm package config" FORCE) + find_package(rmm CONFIG REQUIRED) else() find_package(rmm CONFIG REQUIRED) endif() @@ -189,16 +197,55 @@ function(add_nb_cuda_module target src) endfunction() # An RMM-backed nanobind CUDA module: add_nb_cuda_module plus the shared RMM -# scratch allocator (rmm_scratch.cu) and the rmm::rmm link. librmm.so is resolved -# at runtime via the cuML preload (rapids_singlecell/__init__.py imports cuML -# before these extensions, loading librmm into the process), so no INSTALL_RPATH -# is needed. Reusable by any module that needs RMM device scratch. +# scratch allocator (rmm_scratch.cu) and the rmm::rmm link. Installed wheels +# resolve RAPIDS runtime libs from sibling Python packages; editable source-tree +# imports still have the _cuda/__init__.py preload fallback. function(add_rmm_cuda_module target src) add_nb_cuda_module(${target} ${src}) if (RSC_BUILD_EXTENSIONS) target_sources(${target} PRIVATE src/rapids_singlecell/_cuda/rmm_scratch.cu) target_link_libraries(${target} PRIVATE rmm::rmm) + set(_rsc_rmm_build_rpath) + set(_rsc_rmm_have_build_librmm FALSE) + set(_rsc_rmm_have_build_rapids_logger FALSE) + if (DEFINED ENV{CONDA_PREFIX}) + set(_rsc_rmm_env_site + "$ENV{CONDA_PREFIX}/lib/python${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}/site-packages") + if (EXISTS "${_rsc_rmm_env_site}/librmm/lib64") + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_env_site}/librmm/lib64") + set(_rsc_rmm_have_build_librmm TRUE) + endif() + if (EXISTS "${_rsc_rmm_env_site}/rapids_logger/lib64") + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_env_site}/rapids_logger/lib64") + set(_rsc_rmm_have_build_rapids_logger TRUE) + endif() + endif() + if (NOT _rsc_rmm_have_build_librmm AND rmm_DIR) + get_filename_component(_rsc_rmm_build_librmm_dir + "${rmm_DIR}/../.." REALPATH) + list(APPEND _rsc_rmm_build_rpath "${_rsc_rmm_build_librmm_dir}") + endif() + if (NOT _rsc_rmm_have_build_rapids_logger AND rapids_logger_DIR) + get_filename_component(_rsc_rmm_build_rapids_logger_dir + "${rapids_logger_DIR}/../.." REALPATH) + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_build_rapids_logger_dir}") + endif() + set(_rsc_rmm_install_rpath + "\$ORIGIN/../../librmm/lib64" + "\$ORIGIN/../../rapids_logger/lib64" + ) + if (CUDAToolkit_LIBRARY_DIR) + list(APPEND _rsc_rmm_build_rpath "${CUDAToolkit_LIBRARY_DIR}") + list(APPEND _rsc_rmm_install_rpath "${CUDAToolkit_LIBRARY_DIR}") + endif() + set_target_properties(${target} PROPERTIES + BUILD_RPATH "${_rsc_rmm_build_rpath}" + INSTALL_RPATH "${_rsc_rmm_install_rpath}" + ) endif() endfunction() diff --git a/docker/Dockerfile.deps b/docker/Dockerfile.deps index 6638a67d..aad3b0a5 100644 --- a/docker/Dockerfile.deps +++ b/docker/Dockerfile.deps @@ -1,4 +1,4 @@ -ARG CUDA_VER=13.0.2 +ARG CUDA_VER=13.1.0 ARG LINUX_VER=ubuntu24.04 FROM nvidia/cuda:${CUDA_VER}-devel-${LINUX_VER} @@ -7,7 +7,7 @@ SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ARG PYTHON_VER=3.13 # Re-declare after FROM so it is available to RUN steps (passed by docker.yml build-args) -ARG CUDA_VER=13.0.2 +ARG CUDA_VER=13.1.0 ENV PATH=/opt/conda/bin:$PATH ENV PYTHON_VERSION=${PYTHON_VER} diff --git a/docker/docker-push.sh b/docker/docker-push.sh index 69801f79..4a137fa7 100755 --- a/docker/docker-push.sh +++ b/docker/docker-push.sh @@ -6,7 +6,7 @@ rapids_version=26.04 declare -A cuda_versions=( [cu12]="12.8.0" - [cu13]="13.0.2" + [cu13]="13.1.0" ) declare -A cuda_archs=( diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 23ba8958..855d36a9 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -121,6 +121,10 @@ using gpu_array_contig = nb::ndarray; // Host (NumPy) array aliases template using host_array = nb::ndarray>; +template +using host_array_c2 = nb::ndarray, nb::c_contig>; +template +using host_array_f2 = nb::ndarray, nb::f_contig>; // Register bindings for both regular CUDA and managed-memory arrays. // Usage: diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 74d75574..01b4eb7f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -130,13 +130,15 @@ static void launch_ovr_rank_dense_streaming( template static void launch_ovr_rank_dense_host_streaming( const T* h_X, bool f_order, const int* group_codes, double* rank_sums, - double* tie_corr, double* group_sums, double* group_nnz, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + double* tie_corr, double* group_sums, double* group_nnz, double* total_sums, + double* total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_totals, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; const bool compute_stats = group_sums != nullptr; compute_nnz = compute_nnz && (group_nnz != nullptr); + compute_totals = compute_stats && compute_totals && (total_sums != nullptr); // F-order float32 input feeds the sort directly (no cast/transpose buffer). const bool fast_keys = f_order && std::is_same::value; @@ -161,7 +163,10 @@ static void launch_ovr_rank_dense_host_streaming( (size_t)sub_batch_cols * sizeof(double) + (compute_stats ? (size_t)n_groups * sub_batch_cols * sizeof(double) : 0) + - (compute_nnz ? (size_t)n_groups * sub_batch_cols * sizeof(double) : 0); + (compute_nnz ? (size_t)n_groups * sub_batch_cols * sizeof(double) : 0) + + (compute_totals ? (size_t)sub_batch_cols * sizeof(double) : 0) + + (compute_totals && compute_nnz ? (size_t)sub_batch_cols * sizeof(double) + : 0); n_streams = clamp_streams_by_budget(n_streams, per_stream_bytes, rmm_available_device_bytes(0.8)); @@ -185,6 +190,8 @@ static void launch_ovr_rank_dense_host_streaming( double* sub_tie_corr; double* sub_group_sums; double* sub_group_nnz; + double* sub_total_sums; + double* sub_total_nnz; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; ++s) { @@ -205,6 +212,11 @@ static void launch_ovr_rank_dense_host_streaming( bufs[s].sub_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; + bufs[s].sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; } int tpb_rank = round_up_to_warp(n_rows); @@ -298,11 +310,22 @@ static void launch_ovr_rank_dense_host_streaming( (size_t)n_groups * sb_cols * sizeof(double), stream); } + if (compute_totals) { + cudaMemsetAsync(buf.sub_total_sums, 0, sb_cols * sizeof(double), + stream); + if (compute_nnz) { + cudaMemsetAsync(buf.sub_total_nnz, 0, + sb_cols * sizeof(double), stream); + } + } dense_group_accumulate_kernel <<>>( buf.d_stg, group_codes, buf.sub_group_sums, compute_nnz ? buf.sub_group_nnz : buf.sub_group_sums, - n_rows, sb_cols, n_groups, f_order, compute_nnz); + buf.sub_total_sums, + compute_nnz ? buf.sub_total_nnz : buf.sub_total_sums, + n_rows, sb_cols, n_groups, f_order, compute_nnz, + compute_totals); CUDA_CHECK_LAST_ERROR(dense_group_accumulate_kernel); scatter_cols_2d(group_sums + col, buf.sub_group_sums, n_groups, n_cols, sb_cols, stream); @@ -310,6 +333,16 @@ static void launch_ovr_rank_dense_host_streaming( scatter_cols_2d(group_nnz + col, buf.sub_group_nnz, n_groups, n_cols, sb_cols, stream); } + if (compute_totals) { + cudaMemcpyAsync(total_sums + col, buf.sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(total_nnz + col, buf.sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } } col += sb_cols; @@ -821,19 +854,21 @@ static void launch_ovo_rank_dense_host_streaming( sync_streams(streams, "dense host OVO streaming"); } -template +template static void def_ovr_rank_dense_host_streaming(nb::module_& m) { m.def( "ovr_rank_dense_host_streaming", - [](host_array buf, gpu_array_c group_codes, + [](HostArray X, gpu_array_c group_codes, gpu_array_c rank_sums, gpu_array_c tie_corr, gpu_array_c group_sums, - gpu_array_c group_nnz, int n_rows, int n_cols, - int n_groups, bool f_order, bool compute_tie_corr, bool compute_nnz, - bool compute_stats, int sub_batch_cols) { - nb_require(buf.shape(0) == (size_t)n_rows * (size_t)n_cols, - "ovr_rank_host: buf length must be n_rows*n_cols"); + gpu_array_c group_nnz, + gpu_array_c total_sums, + gpu_array_c total_nnz, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_stats, + bool compute_totals, int sub_batch_cols) { + int n_rows = (int)X.shape(0); + int n_cols = (int)X.shape(1); nb_require((int)group_codes.shape(0) == n_rows, "ovr_rank_host: group_codes length must be n_rows"); nb_require( @@ -843,32 +878,35 @@ static void def_ovr_rank_dense_host_streaming(nb::module_& m) { nb_require((int)tie_corr.shape(0) == n_cols, "ovr_rank_host: tie_corr length must be n_cols"); launch_ovr_rank_dense_host_streaming( - buf.data(), f_order, group_codes.data(), rank_sums.data(), + X.data(), FOrder, group_codes.data(), rank_sums.data(), tie_corr.data(), compute_stats ? group_sums.data() : nullptr, - compute_nnz ? group_nnz.data() : nullptr, n_rows, n_cols, - n_groups, compute_tie_corr, compute_nnz, sub_batch_cols); + compute_nnz ? group_nnz.data() : nullptr, + compute_totals ? total_sums.data() : nullptr, + (compute_totals && compute_nnz) ? total_nnz.data() : nullptr, + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, + compute_totals, sub_batch_cols); }, - "buf"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, "group_sums"_a, - "group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, - "f_order"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); + "X"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, "group_sums"_a, + "group_nnz"_a, "total_sums"_a, "total_nnz"_a, nb::kw_only(), + "n_groups"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, + "compute_totals"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); } -template +template static void def_ovo_rank_dense_host_streaming(nb::module_& m) { m.def( "ovo_rank_dense_host_streaming", - [](host_array buf, host_array ref_row_ids, + [](HostArray X, host_array ref_row_ids, host_array grp_row_ids, host_array grp_offsets, gpu_array_c rank_sums, gpu_array_c tie_corr, gpu_array_c group_sums, gpu_array_c group_sum_sq, - gpu_array_c group_nnz, int n_full_rows, int n_cols, - int n_groups, bool f_order, bool compute_tie_corr, bool compute_nnz, - bool compute_stats, int sub_batch_cols) { - nb_require(buf.shape(0) == (size_t)n_full_rows * (size_t)n_cols, - "ovo_rank_host: buf length must be n_rows*n_cols"); + gpu_array_c group_nnz, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_stats, + int sub_batch_cols) { + int n_full_rows = (int)X.shape(0); + int n_cols = (int)X.shape(1); int n_ref = (int)ref_row_ids.shape(0); int n_all_grp = (int)grp_row_ids.shape(0); nb_require((int)grp_offsets.shape(0) == n_groups + 1, @@ -904,7 +942,7 @@ static void def_ovo_rank_dense_host_streaming(nb::module_& m) { } } launch_ovo_rank_dense_host_streaming( - buf.data(), f_order, ref_row_ids.data(), grp_row_ids.data(), + X.data(), FOrder, ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), rank_sums.data(), tie_corr.data(), compute_stats ? group_sums.data() : nullptr, compute_stats ? group_sum_sq.data() : nullptr, @@ -912,21 +950,32 @@ static void def_ovo_rank_dense_host_streaming(nb::module_& m) { n_all_grp, n_cols, n_groups, n_groups_stats, compute_tie_corr, compute_nnz, compute_stats, sub_batch_cols); }, - "buf"_a, "ref_row_ids"_a, "grp_row_ids"_a, "grp_offsets"_a, - "rank_sums"_a, "tie_corr"_a, "group_sums"_a, "group_sum_sq"_a, - "group_nnz"_a, nb::kw_only(), "n_full_rows"_a, "n_cols"_a, "n_groups"_a, - "f_order"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); + "X"_a, "ref_row_ids"_a, "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, "group_sums"_a, "group_sum_sq"_a, "group_nnz"_a, + nb::kw_only(), "n_groups"_a, "compute_tie_corr"_a, "compute_nnz"_a, + "compute_stats"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); } template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - def_ovr_rank_dense_host_streaming(m); - def_ovr_rank_dense_host_streaming(m); - def_ovo_rank_dense_host_streaming(m); - def_ovo_rank_dense_host_streaming(m); + def_ovr_rank_dense_host_streaming, + false>(m); + def_ovr_rank_dense_host_streaming, + true>(m); + def_ovr_rank_dense_host_streaming, false>(m); + def_ovr_rank_dense_host_streaming, true>(m); + def_ovo_rank_dense_host_streaming, + false>(m); + def_ovo_rank_dense_host_streaming, + true>(m); + def_ovo_rank_dense_host_streaming, false>(m); + def_ovo_rank_dense_host_streaming, true>(m); m.def( "ovo_rank_dense_tiered_unsorted_ref", diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 6ad19806..4812881a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -264,8 +264,8 @@ static void ovo_streaming_csc_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); bool cast_use_gmem = false; - size_t smem_cast = - cast_accumulate_smem_config(n_groups_stats, compute_nnz, cast_use_gmem); + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_nnz, /*compute_totals=*/false, cast_use_gmem); int col = 0; int batch_idx = 0; @@ -305,8 +305,9 @@ static void ovo_streaming_csc_host_impl( launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_f32, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_nnz, - sb_cols, n_groups_stats, compute_nnz, UTIL_BLOCK_SIZE, smem_cast, - cast_use_gmem, stream); + nullptr, nullptr, sb_cols, n_groups_stats, compute_nnz, + /*compute_totals=*/false, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, + stream); // Extract ref from CSC via row_map, sort cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), @@ -419,14 +420,17 @@ static void ovo_streaming_csr_host_impl( if (max_pack_rows == 0) return; RmmScratchPool pool; + ScopedCudaStream ref_stream(cudaStreamNonBlocking); if (compute_sums) { cudaMemsetAsync(d_group_sums, 0, - (size_t)n_groups_stats * n_cols * sizeof(double)); + (size_t)n_groups_stats * n_cols * sizeof(double), + ref_stream); } if (compute_nnz) { cudaMemsetAsync(d_group_nnz, 0, - (size_t)n_groups_stats * n_cols * sizeof(double)); + (size_t)n_groups_stats * n_cols * sizeof(double), + ref_stream); } // No full-matrix page-lock (the 280GB cudaHostRegister was ~7s/call). The @@ -478,7 +482,6 @@ static void ovo_streaming_csr_host_impl( int ref_chunk_items_i32 = checked_cub_items(ref_chunk_items, "OVO host CSR ref column chunk"); float* d_ref_sorted = pool.alloc(ref_items); - ScopedCudaStream ref_stream(cudaStreamNonBlocking); { ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 3b77d9be..e516d35e 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -104,15 +104,23 @@ __global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, template __global__ void dense_group_accumulate_kernel( const T* __restrict__ stg, const int* __restrict__ group_codes, - double* __restrict__ group_sums, double* __restrict__ group_nnz, int n_rows, - int sb_cols, int n_groups, bool f_order, bool compute_nnz) { + double* __restrict__ group_sums, double* __restrict__ group_nnz, + double* __restrict__ total_sums, double* __restrict__ total_nnz, int n_rows, + int sb_cols, int n_groups, bool f_order, bool compute_nnz, + bool compute_totals) { int col = blockIdx.x; if (col >= sb_cols) return; for (int row = threadIdx.x; row < n_rows; row += blockDim.x) { - int g = group_codes[row]; - if (g < 0 || g >= n_groups) continue; double v = f_order ? (double)stg[(long long)col * n_rows + row] : (double)stg[(long long)row * sb_cols + col]; + if (compute_totals) { + atomicAdd(&total_sums[col], v); + if (compute_nnz && v != 0.0) { + atomicAdd(&total_nnz[col], 1.0); + } + } + int g = group_codes[row]; + if (g < 0 || g >= n_groups) continue; atomicAdd(&group_sums[(long long)g * sb_cols + col], v); if (compute_nnz && v != 0.0) { atomicAdd(&group_nnz[(long long)g * sb_cols + col], 1.0); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index d435df6c..8aac2f45 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -9,8 +9,9 @@ template static void ovr_sparse_csc_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -73,6 +74,8 @@ static void ovr_sparse_csc_host_streaming_impl( double* d_tie_corr; double* d_group_sums; double* d_group_nnz; + double* d_total_sums; + double* d_total_nnz; double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem }; std::vector bufs(n_streams); @@ -94,6 +97,11 @@ static void ovr_sparse_csc_host_streaming_impl( bufs[s].d_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; + bufs[s].d_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].d_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; } cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), @@ -108,8 +116,8 @@ static void ovr_sparse_csc_host_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; - size_t smem_cast = - cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); // gmem mode: rank kernel accumulates into rank_sums directly, needs a // per-stream nz_count scratch buffer sized (n_groups, sb_cols). @@ -186,8 +194,8 @@ static void ovr_sparse_csc_host_streaming_impl( launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, idx32, buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_nnz, - sb_cols, n_groups, compute_nnz, tpb, smem_cast, cast_use_gmem, - stream); + buf.d_total_sums, buf.d_total_nnz, sb_cols, n_groups, compute_nnz, + compute_totals, tpb, smem_cast, cast_use_gmem, stream); // Sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { @@ -217,6 +225,16 @@ static void ovr_sparse_csc_host_streaming_impl( scatter_cols_2d(d_group_nnz + col, buf.d_group_nnz, n_groups, n_cols, sb_cols, stream); } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, buf.d_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, buf.d_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } col += sb_cols; batch_idx++; @@ -243,8 +261,9 @@ template static void ovr_sparse_csr_host_rowstream_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; size_t total_nnz = (size_t)h_indptr[n_rows]; @@ -296,8 +315,8 @@ static void ovr_sparse_csr_host_rowstream_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; - size_t smem_cast = - cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); // ---- Host gather staging (pinned for bulk H2D) + per-row cursor. Full CSR // NOT page-locked: gather reads it on CPU, only compacted slice crosses @@ -334,6 +353,11 @@ static void ovr_sparse_csr_host_rowstream_impl( double* sub_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; double* d_nz_scratch = rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; @@ -393,8 +417,9 @@ static void ovr_sparse_csr_host_rowstream_impl( launch_ovr_cast_and_accumulate_sparse( csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, - d_group_codes, sub_group_sums, sub_group_nnz, sb_cols, n_groups, - compute_nnz, tpb, smem_cast, cast_use_gmem, stream); + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, tpb, + smem_cast, cast_use_gmem, stream); if (batch_nnz > 0) { cub_segmented_sortpairs(cub_temp, cub_temp_bytes, csc_vals_f32, keys_out, csc_row_idx, vals_out, batch_nnz, @@ -423,6 +448,16 @@ static void ovr_sparse_csr_host_rowstream_impl( sub_group_nnz, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } col += sb_cols; } cuda_check(cudaStreamSynchronize(stream), "rowstream sync"); @@ -438,8 +473,9 @@ template static void ovr_sparse_csr_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool compute_nnz, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -462,8 +498,9 @@ static void ovr_sparse_csr_host_streaming_impl( if (total_nnz > 0 && data_bytes + idx_bytes > (budget * 3) / 4) { ovr_sparse_csr_host_rowstream_impl( h_data, h_indices, h_indptr, h_group_codes, h_group_sizes, - d_rank_sums, d_tie_corr, d_group_sums, d_group_nnz, n_rows, n_cols, - n_groups, compute_tie_corr, compute_nnz, sub_batch_cols); + d_rank_sums, d_tie_corr, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, n_rows, n_cols, n_groups, compute_tie_corr, + compute_nnz, compute_totals, sub_batch_cols); return; } @@ -540,8 +577,8 @@ static void ovr_sparse_csr_host_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); bool cast_use_gmem = false; - size_t smem_cast = - cast_accumulate_smem_config(n_groups, compute_nnz, cast_use_gmem); + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); size_t per_stream_bytes = max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + @@ -551,6 +588,12 @@ static void ovr_sparse_csr_host_streaming_impl( if (compute_nnz) { per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } + if (compute_totals) { + per_stream_bytes += sub_batch_cols * sizeof(double); + if (compute_nnz) { + per_stream_bytes += sub_batch_cols * sizeof(double); + } + } if (rank_use_gmem) { per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } @@ -608,6 +651,8 @@ static void ovr_sparse_csr_host_streaming_impl( double* sub_tie_corr; double* sub_group_sums; double* sub_group_nnz; + double* sub_total_sums; + double* sub_total_nnz; double* d_nz_scratch; }; std::vector bufs(n_streams); @@ -628,6 +673,11 @@ static void ovr_sparse_csr_host_streaming_impl( bufs[s].sub_group_nnz = compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) : nullptr; + bufs[s].sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; bufs[s].d_nz_scratch = rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) @@ -664,7 +714,8 @@ static void ovr_sparse_csr_host_streaming_impl( launch_ovr_cast_and_accumulate_sparse( buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, buf.col_offsets, d_group_codes, buf.sub_group_sums, - buf.sub_group_nnz, sb_cols, n_groups, compute_nnz, tpb, smem_cast, + buf.sub_group_nnz, buf.sub_total_sums, buf.sub_total_nnz, sb_cols, + n_groups, compute_nnz, compute_totals, tpb, smem_cast, cast_use_gmem, stream); if (batch_nnz > 0) { @@ -694,6 +745,16 @@ static void ovr_sparse_csr_host_streaming_impl( scatter_cols_2d(d_group_nnz + col, buf.sub_group_nnz, n_groups, n_cols, sb_cols, stream); } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, buf.sub_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, buf.sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } col += sb_cols; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index c24913bc..55d28b63 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -43,16 +43,12 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", ovr_sparse_csc_streaming_impl, int, int); - RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device_i64", - ovr_sparse_csc_streaming_impl, int, int64_t); - RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device_i64_idx64", + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", ovr_sparse_csc_streaming_impl, int64_t, int64_t); RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", ovr_sparse_csr_streaming_impl, int, int); - RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device_i64", - ovr_sparse_csr_streaming_impl, int, int64_t); - RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device_i64_idx64", + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", ovr_sparse_csr_streaming_impl, int64_t, int64_t); #undef RSC_OVR_SPARSE_DEVICE_BINDING @@ -67,36 +63,33 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_nnz, \ - int sub_batch_cols) { \ + bool compute_totals, int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csc_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_nnz.data(), n_rows, n_cols, n_groups, \ - compute_tie_corr, compute_nnz, sub_batch_cols); \ + d_group_nnz.data(), d_total_sums.data(), d_total_nnz.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, \ + compute_totals, sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "compute_tie_corr"_a, "compute_nnz"_a = true, \ + "d_group_nnz"_a, "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int, + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", double, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int64_t, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int, - int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", double, int64_t, int64_t); - // int64 row indices: pass natively, downcast to int32 per-batch on-device - // (avoids a full host int32 copy). - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64_idx64", float, - int64_t, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64_idx64", double, - int64_t, int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING #define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -109,36 +102,33 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_rank_sums, \ gpu_array_c d_tie_corr, \ gpu_array_c d_group_sums, \ - gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_nnz, \ - int sub_batch_cols) { \ + bool compute_totals, int sub_batch_cols) { \ if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ ovr_sparse_csr_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_nnz.data(), n_rows, n_cols, n_groups, \ - compute_tie_corr, compute_nnz, sub_batch_cols); \ + d_group_nnz.data(), d_total_sums.data(), d_total_nnz.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, \ + compute_totals, sub_batch_cols); \ }, \ "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "compute_tie_corr"_a, "compute_nnz"_a = true, \ + "d_group_nnz"_a, "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", double, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int64_t, int64_t); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, - int); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", double, int64_t, int64_t); - // int64 column indices: pass natively to avoid a full int32 copy of every - // nonzero (~nnz*4 bytes) on large matrices. - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64_idx64", float, - int64_t, int64_t); - RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64_idx64", double, - int64_t, int64_t); #undef RSC_OVR_SPARSE_CSR_HOST_BINDING #define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ @@ -167,16 +157,12 @@ void register_sparse_bindings(nb::module_& m) { RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, int, int); - RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device_i64", - ovo_streaming_csc_impl, int, int64_t); - RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device_i64_idx64", - ovo_streaming_csc_impl, int64_t, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, + int64_t, int64_t); RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, int, int); - RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device_i64", - ovo_streaming_csr_impl, int, int64_t); - RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device_i64_idx64", - ovo_streaming_csr_impl, int64_t, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, + int64_t, int64_t); #undef RSC_OVO_DEVICE_BINDING #define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -212,16 +198,10 @@ void register_sparse_bindings(nb::module_& m) { "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int64_t, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", double, int64_t, int64_t); - // int64 row indices: read natively (extraction only, never sorted), - // skipping the full host int32 copy. - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64_idx64", float, int64_t, - int64_t); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64_idx64", double, - int64_t, int64_t); #undef RSC_OVO_CSC_HOST_BINDING #define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ @@ -256,16 +236,10 @@ void register_sparse_bindings(nb::module_& m) { "sub_batch_cols"_a = SUB_BATCH_COLS) RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, - int64_t); - // int64 column indices: pass natively to avoid a full int32 copy of every - // nonzero (~nnz*4 bytes) on large matrices. - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64_idx64", float, int64_t, + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int64_t, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", double, int64_t, int64_t); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64_idx64", double, - int64_t, int64_t); #undef RSC_OVO_CSR_HOST_BINDING } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 73a51e74..7ecd99ac 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -173,15 +173,16 @@ static inline void launch_ovr_sparse_rank( // selects ovr_cast_and_accumulate_sparse_global_kernel (accumulates in gmem). // Load-bearing fallback, not dead. static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, - bool& use_gmem) { + bool compute_totals, bool& use_gmem) { int n_arrays = 1 + (compute_nnz ? 1 : 0); size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (compute_totals) need += WARP_REDUCE_BUF * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } use_gmem = true; - return 0; + return compute_totals ? WARP_REDUCE_BUF * sizeof(double) : 0; } // Shared cast+accumulate loop for the two sparse-OVR stats kernels. Casts each @@ -191,11 +192,16 @@ template __device__ __forceinline__ void accumulate_group_stats( const InT* data_in, float* data_f32_out, const IndexT* indices, int seg_start, int seg_end, const int* group_codes, double* sums, - double* nnz, int acc_stride, int n_groups, bool compute_nnz) { + double* nnz, int acc_stride, int n_groups, bool compute_nnz, + bool compute_totals, double& local_total_sum, double& local_total_nnz) { for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { InT v_in = data_in[i]; double v = (double)v_in; data_f32_out[i] = (float)v_in; + if (compute_totals) { + local_total_sum += v; + if (compute_nnz && v != 0.0) local_total_nnz += 1.0; + } int row = (int)indices[i]; int g = group_codes[row]; if (g >= 0 && g < n_groups) { @@ -218,8 +224,9 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_nnz, int sb_cols, int n_groups, - bool compute_nnz = true) { + double* __restrict__ group_nnz, double* __restrict__ total_sums, + double* __restrict__ total_nnz, int sb_cols, int n_groups, + bool compute_nnz = true, bool compute_totals = false) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -231,6 +238,7 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( extern __shared__ double smem[]; double* s_sum = smem; double* s_nnz = smem + n_groups; + double* warp_buf = smem + (size_t)(1 + (compute_nnz ? 1 : 0)) * n_groups; for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { s_sum[g] = 0.0; @@ -238,11 +246,25 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } __syncthreads(); + double local_total_sum = 0.0; + double local_total_nnz = 0.0; accumulate_group_stats( data_in, data_f32_out, indices, seg_start, seg_end, group_codes, s_sum, - s_nnz, /*acc_stride=*/1, n_groups, compute_nnz); + s_nnz, /*acc_stride=*/1, n_groups, compute_nnz, compute_totals, + local_total_sum, local_total_nnz); __syncthreads(); + if (compute_totals) { + double total = wilcoxon_block_sum(local_total_sum, warp_buf); + if (threadIdx.x == 0) total_sums[col] = total; + __syncthreads(); + if (compute_nnz) { + double nnz_total = wilcoxon_block_sum(local_total_nnz, warp_buf); + if (threadIdx.x == 0) total_nnz[col] = nnz_total; + __syncthreads(); + } + } + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { group_sums[(size_t)g * sb_cols + col] = s_sum[g]; if (compute_nnz) { @@ -260,25 +282,40 @@ __global__ void ovr_cast_and_accumulate_sparse_global_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, double* __restrict__ group_sums, - double* __restrict__ group_nnz, int sb_cols, int n_groups, - bool compute_nnz = true) { + double* __restrict__ group_nnz, double* __restrict__ total_sums, + double* __restrict__ total_nnz, int sb_cols, int n_groups, + bool compute_nnz = true, bool compute_totals = false) { int col = blockIdx.x; if (col >= sb_cols) return; int seg_start = col_seg_offsets[col]; int seg_end = col_seg_offsets[col + 1]; + extern __shared__ double warp_buf[]; + double local_total_sum = 0.0; + double local_total_nnz = 0.0; accumulate_group_stats( data_in, data_f32_out, indices, seg_start, seg_end, group_codes, group_sums + col, group_nnz + col, - /*acc_stride=*/sb_cols, n_groups, compute_nnz); + /*acc_stride=*/sb_cols, n_groups, compute_nnz, compute_totals, + local_total_sum, local_total_nnz); + if (compute_totals) { + double total = wilcoxon_block_sum(local_total_sum, warp_buf); + if (threadIdx.x == 0) total_sums[col] = total; + __syncthreads(); + if (compute_nnz) { + double nnz_total = wilcoxon_block_sum(local_total_nnz, warp_buf); + if (threadIdx.x == 0) total_nnz[col] = nnz_total; + } + } } template static void launch_ovr_cast_and_accumulate_sparse( const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, - double* d_group_nnz, int sb_cols, int n_groups, bool compute_nnz, int tpb, + double* d_group_nnz, double* d_total_sums, double* d_total_nnz, int sb_cols, + int n_groups, bool compute_nnz, bool compute_totals, int tpb, size_t smem_cast, bool use_gmem, cudaStream_t stream) { if (use_gmem) { size_t stats_items = (size_t)n_groups * sb_cols; @@ -288,17 +325,17 @@ static void launch_ovr_cast_and_accumulate_sparse( stream); } ovr_cast_and_accumulate_sparse_global_kernel - <<>>(d_data_orig, d_data_f32, d_indices, - d_col_offsets, d_group_codes, - d_group_sums, d_group_nnz, sb_cols, - n_groups, compute_nnz); + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals); CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); } else { ovr_cast_and_accumulate_sparse_kernel <<>>( d_data_orig, d_data_f32, d_indices, d_col_offsets, - d_group_codes, d_group_sums, d_group_nnz, sb_cols, n_groups, - compute_nnz); + d_group_codes, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals); CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); } } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 913abe89..32a59452 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -26,6 +26,8 @@ def _sparse_has_negative(X) -> bool: Dask sparse separately. Dense and t-test/logreg never need this. """ if sp.issparse(X) or cpsp.issparse(X): + if np.dtype(X.data.dtype).kind == "c": + return False return X.nnz > 0 and float(X.data.min()) < 0 return False diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index e867fea8..eababe07 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from dataclasses import dataclass from typing import TYPE_CHECKING import cupy as cp @@ -37,6 +38,22 @@ OVO_DENSE_TIERED_SUB_BATCH = 256 +@dataclass(frozen=True) +class _OvoContext: + codes: np.ndarray + n_groups: int + ireference: int + n_ref: int + ref_row_ids: np.ndarray + test_group_indices: list[int] + all_grp_row_ids: np.ndarray + offsets_np: np.ndarray + offsets_gpu: cp.ndarray + n_all_grp: int + n_test: int + test_sizes: cp.ndarray + + def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: if requested is not None: return _choose_chunk_size(requested) @@ -362,46 +379,63 @@ def _finish_ovo( ] -def _host_sparse_fn_and_arrays(module, base_name: str, X): +def _host_sparse_data_array(X): data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float64: - is_f64 = True - data_arr = X.data - elif data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: - is_f64 = False - data_arr = X.data.astype(np.float32, copy=False) - else: + return X.data + if data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: + return X.data.astype(np.float32, copy=False) + if data_dtype.kind == "c": + msg = ( + "Wilcoxon sparse input data dtype must be real; complex sparse " + "data is not supported." + ) + raise TypeError(msg) + msg = ( + "Wilcoxon sparse input data dtype must be float32, float64, bool, " + f"or integer; got {data_dtype}." + ) + raise TypeError(msg) + + +def _validate_wilcoxon_sparse_dtype(X) -> None: + if not (sp.issparse(X) or cpsp.issparse(X)): + return + data_dtype = np.dtype(X.data.dtype) + if data_dtype.kind == "c": + msg = ( + "Wilcoxon sparse input data dtype must be real; complex sparse " + "data is not supported." + ) + raise TypeError(msg) + if cpsp.issparse(X) and data_dtype not in { + np.dtype(np.float32), + np.dtype(np.float64), + }: msg = ( - "Wilcoxon sparse input data dtype must be float32, float64, bool, " - f"or integer; got {data_dtype}." + "Wilcoxon device sparse input data dtype must be float32 or " + f"float64; got {data_dtype}." ) raise TypeError(msg) + if getattr(X, "format", None) in {"csr", "csc"}: + indices_dtype = np.dtype(X.indices.dtype) + indptr_dtype = np.dtype(X.indptr.dtype) + if indices_dtype != indptr_dtype: + msg = ( + "Wilcoxon sparse indices and indptr must have the same dtype; " + f"got indices={indices_dtype} and indptr={indptr_dtype}." + ) + raise TypeError(msg) + if indices_dtype not in {np.dtype(np.int32), np.dtype(np.int64)}: + msg = ( + "Wilcoxon sparse indices and indptr must be int32 or int64; " + f"got {indices_dtype}." + ) + raise TypeError(msg) - # Indices fit int32 (cells/genes < 2^31); only indptr (cumulative nnz) needs int64. - is_i64 = X.indptr.dtype == np.int64 - # The *_f64 binding only changes the host pointer dtype to accept float64 data - # without a host copy; it still ranks in float32 on-device (kernels cast InT -> - # float before the segmented sort). See _device_sparse_arrays_f32. - suffix = "" - if is_f64: - suffix += "_f64" - if is_i64: - suffix += "_i64" - # int64 column indices: if a native-int64 binding exists for this path, use - # it and pass the indices as-is. astype(int32) on int64 indices materializes - # a full copy of every nonzero (~nnz * 4 bytes, e.g. tens of GB on large - # matrices), so avoid it when the kernel can read int64 directly. - if X.indices.dtype == np.int64: - idx_fn = getattr(module, base_name + suffix + "_idx64", None) - if idx_fn is not None: - return idx_fn, data_arr, X.indices - fn = getattr(module, base_name + suffix) - indices_arr = X.indices.astype(np.int32, copy=False) - return fn, data_arr, indices_arr - - -def _device_sparse_arrays_f32(X): - """Cast device-sparse arrays for the Wilcoxon kernels. + +def _device_sparse_arrays(X): + """Prepare device-sparse arrays for the Wilcoxon kernels. Wilcoxon ranking sorts float32 keys on every path -- the sparse fast paths AND the dense fallback (``_ovr_dense_block_f32``); the CUB segmented @@ -417,117 +451,34 @@ def _device_sparse_arrays_f32(X): to spare the caller a pre-cast. """ data_dtype = np.dtype(X.data.dtype) - if data_dtype == np.float32 or data_dtype == np.float64: - pass - elif data_dtype.kind in {"b", "i", "u"}: - pass + if data_dtype == np.float32: + data = X.data + elif data_dtype == np.float64: + data = X.data.astype(cp.float32, copy=False) + elif data_dtype.kind == "c": + msg = ( + "Wilcoxon device sparse input data dtype must be real; complex " + "sparse data is not supported." + ) + raise TypeError(msg) else: msg = ( - "Wilcoxon device sparse input data dtype must be float32, float64, " - f"bool, or integer; got {data_dtype}." + "Wilcoxon device sparse input data dtype must be float32 or " + f"float64; got {data_dtype}." ) raise TypeError(msg) - data = X.data.astype(cp.float32, copy=False) - # Pass int64 indices natively to the *_idx64 kernels rather than a full nnz - # int32 copy (indices are only int64 when nnz > 2^31). int64 indices imply - # an int64 indptr, which those kernels require -- promote the (tiny) indptr - # if a hand-built matrix left it int32. Index values always fit int32; the - # CSC kernels downcast per-batch on-device where it's the sort value. + # Keep int64 index buffers native and let the nanobind overloads dispatch by + # dtype. Normal CuPy sparse matrices keep indices and indptr in lockstep. if X.indices.dtype == cp.int64: indices = X.indices - indptr = ( - X.indptr - if X.indptr.dtype == cp.int64 - else X.indptr.astype(cp.int64, copy=False) - ) + indptr = X.indptr else: indices = X.indices.astype(cp.int32, copy=False) - indptr = ( - X.indptr - if X.indptr.dtype == cp.int64 - else X.indptr.astype(cp.int32, copy=False) - ) + indptr = X.indptr.astype(cp.int32, copy=False) return data, indices, indptr -def _device_sparse_fn(module, base_name: str, indptr: cp.ndarray, indices: cp.ndarray): - """Select the device kernel binding (int64 indptr / int64 indices variants).""" - if indptr.dtype == cp.int64: - suffix = "_i64_idx64" if indices.dtype == cp.int64 else "_i64" - else: - suffix = "" - return getattr(module, base_name + suffix) - - -def _column_totals_for_host_matrix( - X, *, compute_nnz: bool -) -> tuple[cp.ndarray, cp.ndarray | None]: - n_cols = X.shape[1] - if isinstance(X, sp.spmatrix | sp.sparray): - data = np.asarray(X.data) - values = data.astype(np.float64, copy=False) - if X.format == "csc": - indptr = np.asarray(X.indptr) - counts = np.diff(indptr) - nonempty = counts > 0 - starts = indptr[:-1][nonempty] - sums = np.zeros(n_cols, dtype=np.float64) - if starts.size: - sums[nonempty] = np.add.reduceat(values, starts) - nnz = None - if compute_nnz: - nnz = np.zeros(n_cols, dtype=np.float64) - if starts.size: - nnz[nonempty] = np.add.reduceat( - (data != 0).astype(np.float64, copy=False), starts - ) - elif X.format == "csr": - indices = np.asarray(X.indices, dtype=np.intp) - sums = np.bincount(indices, weights=values, minlength=n_cols).astype( - np.float64, copy=False - ) - nnz = ( - np.bincount( - indices, - weights=(data != 0).astype(np.float64, copy=False), - minlength=n_cols, - ).astype(np.float64, copy=False) - if compute_nnz - else None - ) - else: - raise TypeError( - "Wilcoxon sparse input must be CSR or CSC; refusing hidden " - f"full-matrix conversion from {X.format!r}." - ) - elif isinstance(X, np.ndarray): - sums = X.sum(axis=0, dtype=np.float64) - nnz = (X != 0).sum(axis=0).astype(np.float64) if compute_nnz else None - else: - raise TypeError(f"Unsupported host matrix type: {type(X)}") - - total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) - total_nnz = ( - cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) - if nnz is not None - else None - ) - return total_sums, total_nnz - - -def _host_ovr_totals_if_needed( - X, - group_codes: np.ndarray, - n_groups: int, - *, - compute_nnz: bool, -) -> tuple[cp.ndarray | None, cp.ndarray | None]: - if not np.any(group_codes == n_groups): - return None, None - return _column_totals_for_host_matrix(X, compute_nnz=compute_nnz) - - def wilcoxon( rg: _RankGenes, *, @@ -540,8 +491,9 @@ def wilcoxon( # Host dense OVR and OVO stream column windows from host. Already-device # dense OVO still uses the device-resident tiered planner. # Aggregate if on GPU, else defer to chunks. - rg._basic_stats() X = rg.X + _validate_wilcoxon_sparse_dtype(X) + rg._basic_stats() n_cells, n_total_genes = rg.X.shape group_sizes = rg.group_sizes @@ -569,21 +521,39 @@ def wilcoxon( ) -def _wilcoxon_vs_rest( - rg: _RankGenes, - X, - n_cells: int, - n_total_genes: int, - group_sizes: NDArray, - *, - tie_correct: bool, - use_continuity: bool, - chunk_size: int | None, - return_u_values: bool, -) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs rest of cells.""" - n_groups = len(rg.groups_order) +def _host_sparse_format(X, *, sparse_negative_fallback: bool) -> str | None: + if sparse_negative_fallback or not isinstance(X, sp.spmatrix | sp.sparray): + return None + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + return X.format + +def _device_sparse_format(X, *, sparse_negative_fallback: bool) -> str | None: + if sparse_negative_fallback: + return None + if cpsp.isspmatrix_csc(X): + return "csc" + if cpsp.isspmatrix_csr(X): + return "csr" + return None + + +def _host_dense_matrix(X) -> np.ndarray | None: + if not isinstance(X, np.ndarray): + return None + matrix = X + if matrix.dtype.kind != "f" or matrix.dtype.itemsize < 4: + return np.asarray(matrix, dtype=np.float32, order="F") + if matrix.flags.c_contiguous or matrix.flags.f_contiguous: + return matrix + return np.asfortranarray(matrix) + + +def _warn_small_ovr_groups(rg: _RankGenes, group_sizes: NDArray, n_cells: int) -> None: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: @@ -594,246 +564,416 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - host_sparse = ( - isinstance(X, sp.spmatrix | sp.sparray) and not rg._sparse_negative_fallback + +def _warn_small_ovo_groups( + rg: _RankGenes, ctx: _OvoContext, group_sizes: NDArray +) -> None: + small_groups = [ + str(rg.groups_order[group_index]) + for group_index in ctx.test_group_indices + if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING + ] + if ctx.n_ref > MIN_GROUP_SIZE_WARNING and not small_groups: + return + parts = [] + if small_groups: + parts.append( + f"{len(small_groups)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_groups[:5])}" + f"{'...' if len(small_groups) > 5 else ''})" + ) + if ctx.n_ref <= MIN_GROUP_SIZE_WARNING: + parts.append(f"reference has size {ctx.n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) + + +def _build_ovo_context(rg: _RankGenes, group_sizes: NDArray) -> _OvoContext: + codes = rg.group_codes + n_groups = len(rg.groups_order) + ireference = int(rg.ireference) + n_ref = int(group_sizes[ireference]) + ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) + test_group_indices = [i for i in range(n_groups) if i != ireference] + + offsets = [0] + row_id_parts = [] + for group_index in test_group_indices: + group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) + row_id_parts.append(group_rows) + offsets.append(offsets[-1] + int(group_rows.size)) + + all_grp_row_ids = ( + np.concatenate(row_id_parts).astype(np.int32, copy=False) + if row_id_parts + else np.empty(0, dtype=np.int32) ) - if host_sparse: - if X.format not in {"csr", "csc"}: - raise TypeError( - "Wilcoxon sparse input must be CSR or CSC; refusing hidden " - f"full-matrix conversion from {X.format!r}." - ) + offsets_np = np.asarray(offsets, dtype=np.int32) + test_sizes = cp.asarray( + group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( + np.float64, copy=False + ) + ) + return _OvoContext( + codes=codes, + n_groups=n_groups, + ireference=ireference, + n_ref=n_ref, + ref_row_ids=ref_row_ids, + test_group_indices=test_group_indices, + all_grp_row_ids=all_grp_row_ids, + offsets_np=offsets_np, + offsets_gpu=cp.asarray(offsets_np), + n_all_grp=int(all_grp_row_ids.size), + n_test=len(test_group_indices), + test_sizes=test_sizes, + ) + - group_codes = rg.group_codes.astype(np.int32, copy=False) - group_sizes_np = group_sizes.astype(np.float64, copy=False) - group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) - rest_sizes = n_cells - group_sizes_dev - compute_nnz = rg.comp_pts - - rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - tie_corr = cp.ones(n_total_genes, dtype=cp.float64) - group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - group_nnz = cp.empty( - (n_groups, n_total_genes) if compute_nnz else (1, 1), - dtype=cp.float64, +def _finish_ovo_sparse_stats( + rg: _RankGenes, + ctx: _OvoContext, + group_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: NDArray, +) -> cp.ndarray | None: + if not rg._compute_stats_in_chunks: + return None + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + rg._compute_stats_in_chunks = False + return _ovo_logfoldchanges_from_sums( + rg, + group_sums, + ctx.test_sizes, + ctx.n_ref, ) + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=ctx.test_group_indices, + n_ref=ctx.n_ref, + ) + return None - if X.format == "csc": - csc = X - if not csc.has_sorted_indices: - csc = csc.copy() - csc.sort_indices() - csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csc_host", csc - ) - csc_host_fn( - data_arr, - indices_arr, - csc.indptr, - group_codes, - group_sizes_np, - rank_sums, - tie_corr, - group_sums, - group_nnz, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - compute_nnz=compute_nnz, - sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, - ) - else: - csr = X - if not csr.has_sorted_indices: - csr = csr.copy() - csr.sort_indices() - csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csr_host", csr - ) - csr_host_fn( - data_arr, - indices_arr, - csr.indptr, - group_codes, - group_sizes_np, - rank_sums, - tie_corr, - group_sums, - group_nnz, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - compute_nnz=compute_nnz, - sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, - ) - if rg._compute_stats_in_chunks: - total_sums, total_nnz = _host_ovr_totals_if_needed( - X, - group_codes, - n_groups, - compute_nnz=compute_nnz, - ) - _fill_basic_stats_from_accumulators( - rg, - group_sums, - group_nnz, - group_sizes_np, - n_cells=n_cells, - total_sums=total_sums, - total_nnz=total_nnz, - ) +def _finish_ovo_dense_stats( + rg: _RankGenes, + ctx: _OvoContext, + group_sums: cp.ndarray, + group_sum_sq: cp.ndarray, + group_nnz: cp.ndarray, + *, + group_sizes: NDArray, +) -> cp.ndarray | None: + if not rg._compute_stats_in_chunks: + return None + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + rg._compute_stats_in_chunks = False + return _ovo_logfoldchanges_from_sums( + rg, + group_sums, + ctx.test_sizes, + ctx.n_ref, + ) + _fill_ovo_dense_stats_from_accumulators( + rg, + group_sums, + group_sum_sq, + group_nnz, + group_sizes=group_sizes, + test_group_indices=ctx.test_group_indices, + n_ref=ctx.n_ref, + ) + return None + - return _finish_ovr( +def _run_ovr_host_sparse( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _host_sparse_format( + X, sparse_negative_fallback=rg._sparse_negative_fallback + ) + if sparse_format is None: + return None + + n_groups = len(rg.groups_order) + group_codes = rg.group_codes.astype(np.int32, copy=False) + group_sizes_np = group_sizes.astype(np.float64, copy=False) + group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_nnz = rg.comp_pts + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) + compute_totals = bool( + rg._compute_stats_in_chunks and np.any(group_codes == n_groups) + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + + if isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csc": + X.sort_indices() + _wcs.ovr_sparse_csc_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes, + group_sizes_np, rank_sums, - group_sizes_dev, - rest_sizes, - n_cells, tie_corr, - use_continuity=use_continuity, - return_u_values=return_u_values, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, ) - - if ( - cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) - ) and not rg._sparse_negative_fallback: - data, indices, indptr = _device_sparse_arrays_f32(X) - group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) - group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) - rest_sizes = n_cells - group_sizes_dev - rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - tie_corr = cp.ones(n_total_genes, dtype=cp.float64) - if cpsp.isspmatrix_csc(X): - _device_sparse_fn(_wcs, "ovr_sparse_csc_device", indptr, indices)( - data, - indices, - indptr, - group_codes_gpu, - group_sizes_dev, - rank_sums, - tie_corr, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, - ) - else: - sparse_X = X - if not sparse_X.has_sorted_indices: - sparse_X = sparse_X.copy() - sparse_X.sort_indices() - data, indices, indptr = _device_sparse_arrays_f32(sparse_X) - _device_sparse_fn(_wcs, "ovr_sparse_csr_device", indptr, indices)( - data, - indices, - indptr, - group_codes_gpu, - group_sizes_dev, - rank_sums, - tie_corr, - n_rows=n_cells, - n_cols=n_total_genes, - n_groups=n_groups, - compute_tie_corr=tie_correct, - sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, - ) - - return _finish_ovr( + else: + X.sort_indices() + _wcs.ovr_sparse_csr_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes, + group_sizes_np, rank_sums, - group_sizes_dev, - rest_sizes, - n_cells, tie_corr, - use_continuity=use_continuity, - return_u_values=return_u_values, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, ) - group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + if rg._compute_stats_in_chunks: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes_np, + n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, + ) + + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + +def _run_ovr_device_sparse( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _device_sparse_format( + X, sparse_negative_fallback=rg._sparse_negative_fallback + ) + if sparse_format is None: + return None + + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + n_groups = len(rg.groups_order) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - - # Host dense: stream column chunks from host with the CSC-style pipeline - # (per-batch H2D overlapping the rank), instead of moving the whole array to - # the GPU. Ranking + group sums/nnz come back in one streamed pass. - if isinstance(X, np.ndarray): - # Only float32/float64 host bindings exist; cast int/bool/uint/float16 - # to float32 (mirrors the sparse paths) rather than raising a TypeError. - if X.dtype.kind != "f" or X.dtype.itemsize < 4: - X = X.astype(np.float32) - if X.flags.f_contiguous: - buf, f_order = X.ravel(order="K"), True - elif X.flags.c_contiguous: - buf, f_order = X.ravel(order="K"), False - else: - buf, f_order = np.ascontiguousarray(X).ravel(order="K"), False - compute_nnz = rg.comp_pts - compute_stats = rg._compute_stats_in_chunks - rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - tie_corr = ( - cp.empty(n_total_genes, dtype=cp.float64) - if tie_correct - else cp.ones(n_total_genes, dtype=cp.float64) - ) - stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) - group_sums = cp.empty(stats_shape, dtype=cp.float64) - group_nnz = cp.empty( - (n_groups, n_total_genes) if (compute_stats and compute_nnz) else (1, 1), - dtype=cp.float64, - ) - _wc.ovr_rank_dense_host_streaming( - buf, + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + + if sparse_format == "csc": + _wcs.ovr_sparse_csc_device( + data, + indices, + indptr, group_codes_gpu, + group_sizes_dev, rank_sums, tie_corr, - group_sums, - group_nnz, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_groups, - f_order=f_order, compute_tie_corr=tie_correct, - compute_nnz=compute_stats and compute_nnz, - compute_stats=compute_stats, - sub_batch_cols=OVR_DENSE_SUB_BATCH, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, ) - if compute_stats: - total_sums, total_nnz = _host_ovr_totals_if_needed( - X, rg.group_codes, n_groups, compute_nnz=compute_nnz - ) - _fill_basic_stats_from_accumulators( - rg, - group_sums, - group_nnz, - group_sizes, - n_cells=n_cells, - total_sums=total_sums, - total_nnz=total_nnz, - ) - return _finish_ovr( - rank_sums, + else: + _wcs.ovr_sparse_csr_device( + data, + indices, + indptr, + group_codes_gpu, group_sizes_dev, - rest_sizes, - n_cells, + rank_sums, tie_corr, - use_continuity=use_continuity, - return_u_values=return_u_values, + n_rows=n_cells, + n_cols=n_total_genes, n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, ) - chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + + +def _run_ovr_host_dense( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + matrix = _host_dense_matrix(X) + if matrix is None: + return None + n_groups = len(rg.groups_order) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_nnz = rg.comp_pts + compute_stats = rg._compute_stats_in_chunks + compute_totals = bool(compute_stats and np.any(rg.group_codes == n_groups)) + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = ( + cp.empty(n_total_genes, dtype=cp.float64) + if tie_correct + else cp.ones(n_total_genes, dtype=cp.float64) + ) + stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + (n_groups, n_total_genes) if (compute_stats and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + _wc.ovr_rank_dense_host_streaming( + matrix, + group_codes_gpu, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_stats and compute_nnz, + compute_stats=compute_stats, + compute_totals=compute_totals, + sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + if compute_stats: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes, + n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, + ) + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + +def _run_ovr_dense_chunks( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]]: + n_groups = len(rg.groups_order) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - if rg._compute_stats_in_chunks: block = _get_column_block(X, start, stop) rg._accumulate_chunk_stats_vs_rest( @@ -876,7 +1016,6 @@ def _wilcoxon_vs_rest( use_continuity=use_continuity, return_u_values=return_u_values, ) - scores_host = scores.get() p_host = p_values.get() @@ -890,9 +1029,10 @@ def _wilcoxon_vs_rest( ] -def _wilcoxon_with_reference( +def _wilcoxon_vs_rest( rg: _RankGenes, X, + n_cells: int, n_total_genes: int, group_sizes: NDArray, *, @@ -901,357 +1041,320 @@ def _wilcoxon_with_reference( chunk_size: int | None, return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: all selected groups vs a specific reference group.""" - codes = rg.group_codes - n_groups = len(rg.groups_order) - ireference = rg.ireference - n_ref = int(group_sizes[ireference]) - ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) - - test_group_indices = [i for i in range(n_groups) if i != ireference] - if not test_group_indices: - return [] - - offsets = [0] - row_id_parts = [] - small_groups = [] - for group_index in test_group_indices: - group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) - row_id_parts.append(group_rows) - offsets.append(offsets[-1] + int(group_rows.size)) - if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING: - small_groups.append(str(rg.groups_order[group_index])) - - if n_ref <= MIN_GROUP_SIZE_WARNING or small_groups: - parts = [] - if small_groups: - parts.append( - f"{len(small_groups)} test group(s) have size " - f"<= {MIN_GROUP_SIZE_WARNING} (first few: " - f"{', '.join(small_groups[:5])}" - f"{'...' if len(small_groups) > 5 else ''})" - ) - if n_ref <= MIN_GROUP_SIZE_WARNING: - parts.append(f"reference has size {n_ref}") - warnings.warn( - f"Small groups detected: {'; '.join(parts)}. normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, - ) - - all_grp_row_ids = ( - np.concatenate(row_id_parts).astype(np.int32, copy=False) - if row_id_parts - else np.empty(0, dtype=np.int32) - ) - offsets_np = np.asarray(offsets, dtype=np.int32) - offsets_gpu = cp.asarray(offsets_np) - n_all_grp = int(all_grp_row_ids.size) - n_test = len(test_group_indices) - test_sizes = cp.asarray( - group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( - np.float64, copy=False - ) - ) - - host_sparse = ( - isinstance(X, sp.spmatrix | sp.sparray) and not rg._sparse_negative_fallback - ) - if host_sparse: - if X.format not in {"csr", "csc"}: - raise TypeError( - "Wilcoxon sparse input must be CSR or CSC; refusing hidden " - f"full-matrix conversion from {X.format!r}." - ) - - # zeros, not empty: an all-empty test batch (n_all_grp == 0) - # short-circuits the kernel without writing rank_sums. - rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) - tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) - n_groups_stats = n_test + 1 - compute_sums = rg._compute_stats_in_chunks - compute_nnz = rg.comp_pts - group_sums = cp.empty( - (n_groups_stats, n_total_genes) - if (compute_sums or X.format == "csc") - else (1,), - dtype=cp.float64, - ) - group_nnz = cp.empty( - (n_groups_stats, n_total_genes) if compute_nnz else (1,), - dtype=cp.float64, - ) - - stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) - test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) - stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) - stats_code_lookup[ireference] = n_test - stats_codes = stats_code_lookup[codes] - - if X.format == "csc": - csc = X - if not csc.has_sorted_indices: - csc = csc.copy() - csc.sort_indices() - ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) - ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) - grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) - grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) - csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovo_streaming_csc_host", csc - ) - csc_host_fn( - data_arr, - indices_arr, - csc.indptr, - ref_row_map, - grp_row_map, - offsets_np, - stats_codes, - rank_sums, - tie_corr_arr, - group_sums, - group_nnz, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_rows=X.shape[0], - n_cols=n_total_genes, - n_groups=n_test, - n_groups_stats=n_groups_stats, - compute_tie_corr=tie_correct, - compute_nnz=compute_nnz, - sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, - ) - else: - csr = X - # Host CSR gather scans each row's native index list and tolerates - # unsorted row indices; avoid a full CSR copy just to sort. - csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovo_streaming_csr_host", csr - ) - csr_host_fn( - data_arr, - indices_arr, - csr.indptr, - ref_row_ids.astype(np.int32, copy=False), - all_grp_row_ids.astype(np.int32, copy=False), - offsets_np, - rank_sums, - tie_corr_arr, - group_sums, - group_nnz, - n_full_rows=X.shape[0], - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_test=n_test, - n_groups_stats=n_groups_stats, - compute_tie_corr=tie_correct, - compute_nnz=compute_nnz, - compute_sums=compute_sums, - sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, - ) - - logfoldchanges_gpu = None - if rg._compute_stats_in_chunks: - if rg._store_wilcoxon_gpu_result and not rg.comp_pts: - logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( - rg, - group_sums, - test_sizes, - n_ref, - ) - rg._compute_stats_in_chunks = False - else: - _fill_ovo_stats_from_accumulators( - rg, - group_sums, - group_nnz, - group_sizes=group_sizes, - test_group_indices=test_group_indices, - n_ref=n_ref, - ) - - return _finish_ovo( - rank_sums, - test_sizes, - n_ref, - tie_corr_arr, + """Wilcoxon test: each group vs rest of cells.""" + _warn_small_ovr_groups(rg, group_sizes, n_cells) + for runner in ( + _run_ovr_host_sparse, + _run_ovr_device_sparse, + _run_ovr_host_dense, + ): + result = runner( + rg, + X, + n_cells, + n_total_genes, + group_sizes, tie_correct=tie_correct, use_continuity=use_continuity, return_u_values=return_u_values, - rg=rg, - test_group_indices=test_group_indices, - logfoldchanges_gpu=logfoldchanges_gpu, ) + if result is not None: + return result + return _run_ovr_dense_chunks( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) - if ( - cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X) - ) and not rg._sparse_negative_fallback: - sparse_X = X - if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: - sparse_X = sparse_X.copy() - sparse_X.sort_indices() - data, indices, indptr = _device_sparse_arrays_f32(sparse_X) - # offsets_gpu (built once above as int32) is reused here. - # zeros, not empty: an all-empty test batch (n_all_grp == 0) - # short-circuits the kernel without writing rank_sums. - rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) - tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) - - if cpsp.isspmatrix_csc(sparse_X): - ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) - ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) - grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) - grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) - _device_sparse_fn(_wcs, "ovo_streaming_csc_device", indptr, indices)( - data, - indices, - indptr, - cp.asarray(ref_row_map), - cp.asarray(grp_row_map), - offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, - sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, - ) - else: - _device_sparse_fn(_wcs, "ovo_streaming_csr_device", indptr, indices)( - data, - indices, - indptr, - cp.asarray(ref_row_ids, dtype=cp.int32), - cp.asarray(all_grp_row_ids, dtype=cp.int32), - offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, - sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, - ) - return _finish_ovo( +def _run_ovo_host_sparse( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _host_sparse_format( + X, sparse_negative_fallback=rg._sparse_negative_fallback + ) + if sparse_format is None: + return None + + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + n_groups_stats = ctx.n_test + 1 + compute_sums = rg._compute_stats_in_chunks + compute_nnz = rg.comp_pts + group_sums = cp.empty( + (n_groups_stats, n_total_genes) + if (compute_sums or sparse_format == "csc") + else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) + stats_code_lookup = np.full(ctx.n_groups + 1, n_groups_stats, dtype=np.int32) + test_group_indices_np = np.asarray(ctx.test_group_indices, dtype=np.intp) + stats_code_lookup[test_group_indices_np] = np.arange(ctx.n_test, dtype=np.int32) + stats_code_lookup[ctx.ireference] = ctx.n_test + stats_codes = stats_code_lookup[ctx.codes] + + if sparse_format == "csc": + X.sort_indices() + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ctx.ref_row_ids] = np.arange(ctx.n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[ctx.all_grp_row_ids] = np.arange(ctx.n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + ref_row_map, + grp_row_map, + ctx.offsets_np, + stats_codes, rank_sums, - test_sizes, - n_ref, tie_corr_arr, - tie_correct=tie_correct, - use_continuity=use_continuity, - return_u_values=return_u_values, - rg=rg, - test_group_indices=test_group_indices, - logfoldchanges_gpu=None, - ) - - if isinstance(X, np.ndarray): - if X.dtype.kind != "f" or X.dtype.itemsize < 4: - X = X.astype(np.float32) - if X.flags.f_contiguous: - buf, f_order = X.ravel(order="K"), True - elif X.flags.c_contiguous: - buf, f_order = X.ravel(order="K"), False - else: - buf, f_order = np.ascontiguousarray(X).ravel(order="K"), False - dense_sub_batch_cols = ( - _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) - if chunk_size is not None - else OVO_DENSE_TIERED_SUB_BATCH - ) - - rank_sums = cp.zeros((n_test, n_total_genes), dtype=cp.float64) - tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) - compute_stats = rg._compute_stats_in_chunks - compute_nnz = compute_stats and rg.comp_pts - n_groups_stats = n_test + 1 - stats_shape = (n_groups_stats, n_total_genes) if compute_stats else (1, 1) - group_sums = cp.empty(stats_shape, dtype=cp.float64) - group_sum_sq = cp.empty(stats_shape, dtype=cp.float64) - group_nnz = cp.empty( - stats_shape if compute_nnz else (1, 1), - dtype=cp.float64, + group_sums, + group_nnz, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=ctx.n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) - - _wc.ovo_rank_dense_host_streaming( - buf, - ref_row_ids, - all_grp_row_ids, - offsets_np, + else: + X.sort_indices() + _wcs.ovo_streaming_csr_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + ctx.ref_row_ids, + ctx.all_grp_row_ids, + ctx.offsets_np, rank_sums, tie_corr_arr, group_sums, - group_sum_sq, group_nnz, n_full_rows=X.shape[0], + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, n_cols=n_total_genes, - n_groups=n_test, - f_order=f_order, + n_test=ctx.n_test, + n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, compute_nnz=compute_nnz, - compute_stats=compute_stats, - sub_batch_cols=dense_sub_batch_cols, + compute_sums=compute_sums, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) - logfoldchanges_gpu = None - if compute_stats: - if rg._store_wilcoxon_gpu_result and not rg.comp_pts: - logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( - rg, - group_sums, - test_sizes, - n_ref, - ) - rg._compute_stats_in_chunks = False - else: - _fill_ovo_dense_stats_from_accumulators( - rg, - group_sums, - group_sum_sq, - group_nnz, - group_sizes=group_sizes, - test_group_indices=test_group_indices, - n_ref=n_ref, - ) - - return _finish_ovo( + logfoldchanges_gpu = _finish_ovo_sparse_stats( + rg, ctx, group_sums, group_nnz, group_sizes + ) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, + ) + + +def _run_ovo_device_sparse( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + _group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _device_sparse_format( + X, sparse_negative_fallback=rg._sparse_negative_fallback + ) + if sparse_format is None: + return None + + if isinstance(X, cpsp.spmatrix) and X.format == "csr": + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + + if sparse_format == "csc": + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ctx.ref_row_ids] = np.arange(ctx.n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[ctx.all_grp_row_ids] = np.arange(ctx.n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_device( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + ctx.offsets_gpu, rank_sums, - test_sizes, - n_ref, tie_corr_arr, - tie_correct=tie_correct, - use_continuity=use_continuity, - return_u_values=return_u_values, - rg=rg, - test_group_indices=test_group_indices, - logfoldchanges_gpu=logfoldchanges_gpu, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_total_genes, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + _wcs.ovo_streaming_csr_device( + data, + indices, + indptr, + cp.asarray(ctx.ref_row_ids, dtype=cp.int32), + cp.asarray(ctx.all_grp_row_ids, dtype=cp.int32), + ctx.offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_total_genes, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, ) - chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=None, + ) - scores_host = np.empty((n_test, n_total_genes), dtype=np.float64) - pvals_host = np.empty((n_test, n_total_genes), dtype=np.float64) + +def _run_ovo_host_dense( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + matrix = _host_dense_matrix(X) + if matrix is None: + return None + dense_sub_batch_cols = ( + _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + if chunk_size is not None + else OVO_DENSE_TIERED_SUB_BATCH + ) + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + compute_stats = rg._compute_stats_in_chunks + compute_nnz = compute_stats and rg.comp_pts + n_groups_stats = ctx.n_test + 1 + stats_shape = (n_groups_stats, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_sum_sq = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + stats_shape if compute_nnz else (1, 1), + dtype=cp.float64, + ) + _wc.ovo_rank_dense_host_streaming( + matrix, + ctx.ref_row_ids, + ctx.all_grp_row_ids, + ctx.offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sum_sq, + group_nnz, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_stats=compute_stats, + sub_batch_cols=dense_sub_batch_cols, + ) + logfoldchanges_gpu = _finish_ovo_dense_stats( + rg, + ctx, + group_sums, + group_sum_sq, + group_nnz, + group_sizes=group_sizes, + ) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, + ) + + +def _run_ovo_dense_chunks( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]]: + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + scores_host = np.empty((ctx.n_test, n_total_genes), dtype=np.float64) + pvals_host = np.empty((ctx.n_test, n_total_genes), dtype=np.float64) for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) n_cols = stop - start - - ref_block = _ovo_dense_block(X, ref_row_ids, start, stop) - grp_block = _ovo_dense_block(X, all_grp_row_ids, start, stop) + ref_block = _ovo_dense_block(X, ctx.ref_row_ids, start, stop) + grp_block = _ovo_dense_block(X, ctx.all_grp_row_ids, start, stop) _fill_ovo_chunk_stats( rg, ref_block, grp_block, - offsets=offsets_np, - test_group_indices=test_group_indices, + offsets=ctx.offsets_np, + test_group_indices=ctx.test_group_indices, start=start, stop=stop, group_sizes=group_sizes, @@ -1259,40 +1362,82 @@ def _wilcoxon_with_reference( ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") - # zeros/ones, not empty: an all-empty test batch (n_all_grp == 0) - # short-circuits the kernel, leaving these outputs unwritten. - rank_sums = cp.zeros((n_test, n_cols), dtype=cp.float64) - tie_corr = cp.ones((n_test, n_cols), dtype=cp.float64) - + rank_sums = cp.zeros((ctx.n_test, n_cols), dtype=cp.float64) + tie_corr = cp.ones((ctx.n_test, n_cols), dtype=cp.float64) _wc.ovo_rank_dense_tiered_unsorted_ref( ref_f32, grp_f32, - offsets_gpu, + ctx.offsets_gpu, rank_sums, tie_corr, - n_ref=n_ref, - n_all_grp=n_all_grp, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, n_cols=n_cols, - n_groups=n_test, + n_groups=ctx.n_test, compute_tie_corr=tie_correct, sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) - scores, p_values = _ovo_z_pvals( rank_sums, - test_sizes, - n_ref, + ctx.test_sizes, + ctx.n_ref, tie_corr, tie_correct=tie_correct, use_continuity=use_continuity, return_u_values=return_u_values, ) - scores_host[:, start:stop] = scores.get() pvals_host[:, start:stop] = p_values.get() return [ (group_index, scores_host[slot], pvals_host[slot]) - for slot, group_index in enumerate(test_group_indices) + for slot, group_index in enumerate(ctx.test_group_indices) ] + + +def _wilcoxon_with_reference( + rg: _RankGenes, + X, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]]: + """Wilcoxon test: all selected groups vs a specific reference group.""" + ctx = _build_ovo_context(rg, group_sizes) + if ctx.n_test == 0: + return [] + _warn_small_ovo_groups(rg, ctx, group_sizes) + for runner, extra in ( + (_run_ovo_host_sparse, {}), + (_run_ovo_device_sparse, {}), + (_run_ovo_host_dense, {"chunk_size": chunk_size}), + ): + result = runner( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + **extra, + ) + if result is not None: + return result + return _run_ovo_dense_chunks( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 33c9f2a4..6aba2740 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -85,8 +85,8 @@ def test_rank_genes_groups_sparse_negative_values_fallback(method, reference, fm @pytest.mark.parametrize("reference", ["rest", "1"]) def test_device_sparse_int64_indptr_matches_scanpy(layout, reference): # Real int64 indptr only occurs at nnz > 2^31 (unallocatable in CI). cupy - # >= 14.1 preserves an explicitly promoted int64 indptr, so a small matrix - # promoted to int64 drives the *_i64 device kernels through the public API. + # >= 14.1 preserves explicitly promoted int64 indices/indptr, so a small + # matrix promoted to int64 drives the int64 device overloads. rng = np.random.default_rng(0) dense = np.abs(rng.standard_normal((150, 8))).astype(np.float32) dense[dense < 0.5] = 0.0 @@ -731,7 +731,8 @@ def _make_sized_groups_adata(group_sizes, n_genes, seed=0): # <=~70 (all MEDIUM), so LARGE/HUGE are otherwise never exercised. These force a # single large test group. @pytest.mark.parametrize( - "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] + "fmt", + ["numpy_dense", "cupy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"], ) @pytest.mark.parametrize("tie_correct", [False, True]) @pytest.mark.parametrize("big", [700, 3000], ids=["large_fused", "huge_cub"]) @@ -886,10 +887,10 @@ def run(arr): ) -# F-contiguous host-dense numpy hits the f_order=True branch of the host- +# F-contiguous host-dense numpy hits the F-order nanobind overload of the host # streaming launcher: float32 -> the reinterpret-cast fast path (no cast kernel), # float64 -> dense_block_to_f32_kernel's identity branch. Every numpy_dense -# fixture elsewhere is C-order, so this is the only coverage of that branch. +# fixture elsewhere is C-order, so this is the only coverage of that overload. # AnnData preserves F-order, so an F-contiguous X reaches the path; result must # match the C-order run on identical data. @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -1459,7 +1460,19 @@ def run(arr): ) -@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr"]) +def test_wilcoxon_device_sparse_bool_data_raises(): + counts = np.arange(400).reshape(100, 4) % 3 == 0 + mat = cpsp.csr_matrix(cp.asarray(counts)) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(100)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="float32 or float64"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) def test_wilcoxon_sparse_float16_data_raises(fmt): # Unsupported float16 sparse data (host + device) is rejected with TypeError. rng = np.random.default_rng(0) @@ -1476,6 +1489,21 @@ def test_wilcoxon_sparse_float16_data_raises(fmt): rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_wilcoxon_sparse_complex_data_raises(fmt): + rng = np.random.default_rng(4) + dense = np.abs(rng.standard_normal((40, 4))).astype(np.float32) + dense[dense < 0.4] = 0.0 + mat = _to_format(dense.astype(np.complex64), fmt) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(40)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="complex sparse data is not supported"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + @pytest.mark.parametrize("reference", ["rest", "2"]) def test_wilcoxon_group_subset_column_order_matches_scanpy(reference): """Output column order must echo the user's ``groups=`` list (scanpy parity), @@ -1596,39 +1624,33 @@ def test_wilcoxon_fdr_ties_nan_match_scanpy(): ) -def _promote_host_index_dtypes(X, *, indptr64, indices64): +def _promote_host_index_dtype(X): """Copy a host scipy CSR/CSC matrix with promoted index-array dtypes. - scipy couples indptr/indices to one dtype via get_index_dtype, so the - decoupled (i64 indptr / i32 indices) combination only arises by explicit - promotion -- which is exactly what drives the templated host kernels. + scipy couples indptr/indices to one dtype via get_index_dtype. Real int64 + index buffers only occur at nnz > 2^31 in practice, so tests promote a small + matrix explicitly to drive the int64 templates. """ X = X.copy() - if indptr64: - X.indptr = X.indptr.astype(np.int64) - if indices64: - X.indices = X.indices.astype(np.int64) + X.indptr = X.indptr.astype(np.int64) + X.indices = X.indices.astype(np.int64) return X @pytest.mark.parametrize("reference", ["rest", "1"]) # OVR vs OVO host paths @pytest.mark.parametrize( - ("layout", "data_dtype", "indices64"), + ("layout", "data_dtype"), [ - ("csr", np.float32, False), # *_i64 - ("csr", np.float32, True), # *_i64_idx64 - ("csr", np.float64, False), # *_f64_i64 - ("csr", np.float64, True), # *_f64_i64_idx64 - ("csc", np.float32, False), # *_i64 (CSC has no idx64 template) - ("csc", np.float64, False), # *_f64_i64 + ("csr", np.float32), + ("csr", np.float64), + ("csc", np.float32), + ("csc", np.float64), ], ) -def test_host_sparse_int64_templates_match_int32( - reference, layout, data_dtype, indices64 -): +def test_host_sparse_int64_templates_match_int32(reference, layout, data_dtype): """Exercise the host-sparse int64-indptr / int64-indices kernel templates - (the 12 ``*_i64`` / ``*_idx64`` / ``*_f64_i64`` host bindings the suite - otherwise never reaches). These differ from the validated int32 host path + (the int64-index/indptr overloads the suite otherwise never reaches). + These differ from the validated int32 host path only in index dtype, so they must be bit-identical to it. Real int64 indices only occur at nnz > 2^31 (unallocatable in CI), so we promote a small matrix's index arrays explicitly and keep it host-resident (scipy sparse + @@ -1644,7 +1666,7 @@ def test_host_sparse_int64_templates_match_int32( a32 = sc.AnnData(X=base.copy(), obs=obs.copy(), var=var.copy()) a64 = sc.AnnData( - X=_promote_host_index_dtypes(base, indptr64=True, indices64=indices64), + X=_promote_host_index_dtype(base), obs=obs.copy(), var=var.copy(), ) @@ -2172,30 +2194,25 @@ def test_ovr_device_sparse_subset_match_scanpy(fmt): @pytest.mark.parametrize("reference", ["rest", "1"]) -def test_host_csc_int64_indices_cast_matches_int32(reference): - """Host CSC has no *_idx64 template, so int64 indices are cast to int32 - (_wilcoxon.py:355->357). Result must be bit-identical to the int32 input.""" +@pytest.mark.parametrize("layout", ["csr", "csc"]) +def test_host_sparse_mismatched_index_dtype_raises(reference, layout): + """Host sparse indices/indptr must keep scipy's same-dtype invariant.""" rng = np.random.default_rng(15) dense = rng.integers(0, 5, size=(120, 6)).astype(np.float64) dense[dense < 1.0] = 0.0 obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(120)])}) var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) - base = sp.csc_matrix(dense) - a32 = sc.AnnData(X=base.copy(), obs=obs.copy(), var=var.copy()) - m64 = base.copy() + maker = sp.csr_matrix if layout == "csr" else sp.csc_matrix + m64 = maker(dense) m64.indices = m64.indices.astype(np.int64) # keep indptr int32 - a64 = sc.AnnData(X=m64, obs=obs.copy(), var=var.copy()) + assert m64.indptr.dtype == np.int32 + assert m64.indices.dtype == np.int64 + adata = sc.AnnData(X=m64, obs=obs.copy(), var=var.copy()) kw = { "method": "wilcoxon", "use_raw": False, "reference": reference, "tie_correct": True, } - rsc.tl.rank_genes_groups(a32, "group", **kw) - rsc.tl.rank_genes_groups(a64, "group", **kw) - r32, r64 = a32.uns["rank_genes_groups"], a64.uns["rank_genes_groups"] - for fld in ("scores", "pvals"): - for grp in r32[fld].dtype.names: - np.testing.assert_array_equal( - np.asarray(r64[fld][grp]), np.asarray(r32[fld][grp]) - ) + with pytest.raises(TypeError, match="indices and indptr must have the same dtype"): + rsc.tl.rank_genes_groups(adata, "group", **kw) From 328697a9aaebb5e950896e9c7c8dc695f52d982d Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 26 Jun 2026 13:38:36 +0200 Subject: [PATCH 35/36] make negative fall bag better Signed-off-by: Intron7 --- .../_cuda/sparse_extract/sparse_extract.cuh | 14 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 477 ++++++++++++++++ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 94 +++ .../tools/_rank_genes_groups/__init__.py | 10 +- .../tools/_rank_genes_groups/_core.py | 5 +- .../tools/_rank_genes_groups/_utils.py | 39 +- .../tools/_rank_genes_groups/_wilcoxon.py | 540 +++++++++++------- 7 files changed, 918 insertions(+), 261 deletions(-) diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh index 2496d16c..3c01a5de 100644 --- a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -54,11 +54,12 @@ __global__ void csr_scatter_to_csc_kernel( // `out` must be pre-zeroed; atomicAdd sums duplicate column indices (like // scipy's sum_duplicates) -- bit-identical to dense materialization for // canonical CSR. Output always double; input dtype templated. -template +template __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndexT* __restrict__ indices, const TData* __restrict__ data, - double* __restrict__ out, int col_lb, + OutT* __restrict__ out, int col_lb, int col_ub, int n_cells) { const long long row = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; @@ -75,7 +76,7 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndexT col = indices[k]; if (col >= lb && col < ub) { atomicAdd(&out[static_cast(col - lb) * n_cells + row], - static_cast(data[k])); + static_cast(data[k])); } } } @@ -87,11 +88,12 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, // // `out` must be pre-zeroed. `indptr` indexes columns; pass full-matrix column // pointers (with col_lb/col_ub) or a window rebased to [0, col_ub-col_lb). -template +template __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndexT* __restrict__ indices, const TData* __restrict__ data, - double* __restrict__ out, int col_lb, + OutT* __restrict__ out, int col_lb, int col_ub, int n_cells) { const int col = col_lb + static_cast(blockIdx.x); if (col >= col_ub) return; @@ -100,7 +102,7 @@ __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, const IndptrT e = indptr[col + 1]; for (IndptrT p = s + threadIdx.x; p < e; p += blockDim.x) { const long long row = static_cast(indices[p]); - out[col_local * n_cells + row] = static_cast(data[p]); + out[col_local * n_cells + row] = static_cast(data[p]); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 8aac2f45..e1c4d7ed 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -762,6 +762,483 @@ static void ovr_sparse_csr_host_streaming_impl( sync_streams(streams, "sparse host CSR streaming"); } +// ============================================================================ +// Sign-safe sparse OVR path: sparse window -> dense f32 tile -> dense rank. +// ============================================================================ + +constexpr int SPARSE_DENSE_OVR_CHUNK_COLS = 512; + +static void launch_ovr_dense_rank_window( + const float* dense, const int* group_codes, double* rank_sums, + double* tie_corr, int out_col, int n_rows, int n_cols_total, + int window_cols, int n_groups, bool compute_tie_corr, + int rank_sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || window_cols == 0 || n_groups == 0) return; + if (rank_sub_batch_cols <= 0) rank_sub_batch_cols = SUB_BATCH_COLS; + + DenseColumnBatchPlan batches = plan_dense_column_batches( + n_rows, window_cols, rank_sub_batch_cols, SAFE_BATCH_NNZ, + "Sparse-dense OVR rank sub-batch"); + rank_sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(window_cols, rank_sub_batch_cols); + size_t sub_items = batches.max_items; + int sub_items_i32 = + checked_cub_items(sub_items, "Sparse-dense OVR rank sub-batch"); + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, rank_sub_batch_cols); + + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); + for (int s = 0; s < n_streams; s++) { + cuda_check(cudaStreamWaitEvent(streams[s], inputs_ready.get(), 0), + "wait on sparse-dense OVR tile"); + } + + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(rank_sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * rank_sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(rank_sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < window_cols) { + int sb_cols = std::min(rank_sub_batch_cols, window_cols - col); + int sb_items = + checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Sparse-dense OVR active rank sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + const float* keys_in = dense + (size_t)col * n_rows; + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "sparse-dense OVR segmented sort"); + + if (use_gmem) { + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "sparse-dense OVR gmem rank_sums memset"); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + scatter_cols_2d(rank_sums + out_col + col, buf.sub_rank_sums, n_groups, + n_cols_total, sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + out_col + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "sparse-dense OVR rank"); +} + +template +static void ovr_dense_csr_streaming_impl( + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int chunk_cols, + int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + + DenseColumnBatchPlan chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Device CSR sparse-dense OVR column chunk"); + chunk_cols = chunks.sub_batch_cols; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + float* dense = pool.alloc(max_dense_items); + + for (int col = 0; col < n_cols; col += chunk_cols) { + int sb_cols = std::min(chunk_cols, n_cols - col); + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_tile_to_dense_kernel + <<>>(csr_indptr, csr_indices, + csr_data, dense, col, + col + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csr_tile_to_dense_kernel); + launch_ovr_dense_rank_window( + dense, group_codes, rank_sums, tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + } +} + +template +static void ovr_dense_csc_streaming_impl( + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int chunk_cols, + int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + + DenseColumnBatchPlan chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Device CSC sparse-dense OVR column chunk"); + chunk_cols = chunks.sub_batch_cols; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + float* dense = pool.alloc(max_dense_items); + + for (int col = 0; col < n_cols; col += chunk_cols) { + int sb_cols = std::min(chunk_cols, n_cols - col); + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>(csc_indptr, csc_indices, + csc_data, dense, col, + col + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + launch_ovr_dense_rank_window( + dense, group_codes, rank_sums, tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + } +} + +template +static void ovr_dense_csc_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, double* d_rank_sums, double* d_tie_corr, + double* d_group_sums, double* d_group_nnz, double* d_total_sums, + double* d_total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_stats, bool compute_nnz, + bool compute_totals, int chunk_cols, int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + compute_nnz = compute_stats && compute_nnz && d_group_nnz != nullptr; + compute_totals = compute_stats && compute_totals && d_total_sums != nullptr; + + DenseColumnBatchPlan dense_chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSC sparse-dense OVR column chunk"); + chunk_cols = dense_chunks.sub_batch_cols; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSC sparse-dense OVR offsets"); + chunk_cols = batches.sub_batch_cols; + size_t max_nnz = batches.max_nnz; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + PinnedRing stage(1, max_nnz ? max_nnz : 1); + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + + int* d_group_codes = pool.alloc(n_rows); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int* d_all_offsets = upload_batch_offsets(batches, pool); + + InT* d_sparse_data_orig = pool.alloc(max_nnz ? max_nnz : 1); + float* d_sparse_data_f32 = pool.alloc(max_nnz ? max_nnz : 1); + IndexT* d_sparse_indices = pool.alloc(max_nnz ? max_nnz : 1); + int* idx_i32 = (sizeof(IndexT) > sizeof(int)) + ? pool.alloc(max_nnz ? max_nnz : 1) + : nullptr; + int* d_indptr = pool.alloc(chunk_cols + 1); + float* dense = pool.alloc(max_dense_items); + double* sub_group_sums = + compute_stats ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(chunk_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(chunk_cols) + : nullptr; + + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + int col = 0; + for (int b = 0; b < batches.n_batches; b++) { + int sb_cols = std::min(chunk_cols, n_cols - col); + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "Host CSC sparse-dense active nnz"); + + stage.wait(0); + if (batch_nnz > 0) { + host_copy_slice(h_data, h_indices, (size_t)ptr_start, batch_nnz, + stage.template get<0>(0), stage.template get<1>(0)); + cudaMemcpyAsync(d_sparse_data_orig, stage.template get<0>(0), + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_sparse_indices, stage.template get<1>(0), + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + stage.record(0, stream); + + int* idx32; + if constexpr (sizeof(IndexT) > sizeof(int)) { + if (batch_nnz > 0) { + int cblk = (batch_nnz + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + cast_array_kernel + <<>>( + d_sparse_indices, idx_i32, (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + } + idx32 = idx_i32; + } else { + idx32 = reinterpret_cast(d_sparse_indices); + } + + int* src = d_all_offsets + (size_t)b * (chunk_cols + 1); + cudaMemcpyAsync(d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + launch_ovr_cast_and_accumulate_sparse( + d_sparse_data_orig, d_sparse_data_f32, idx32, d_indptr, + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, + UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>( + d_indptr, idx32, d_sparse_data_f32, dense, 0, sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + + if (compute_stats) { + scatter_cols_2d(d_group_sums + col, sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + } + + launch_ovr_dense_rank_window( + dense, d_group_codes, d_rank_sums, d_tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + col += sb_cols; + } +} + +template +static void ovr_dense_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, double* d_rank_sums, double* d_tie_corr, + double* d_group_sums, double* d_group_nnz, double* d_total_sums, + double* d_total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_stats, bool compute_nnz, + bool compute_totals, int chunk_cols, int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + compute_nnz = compute_stats && compute_nnz && d_group_nnz != nullptr; + compute_totals = compute_stats && compute_totals && d_total_sums != nullptr; + + DenseColumnBatchPlan dense_chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSR sparse-dense OVR column chunk"); + chunk_cols = dense_chunks.sub_batch_cols; + + std::vector h_col_counts(n_cols, 0); + { + int n_workers = host_worker_count(); + std::vector> local(n_workers, + std::vector(n_cols, 0)); + int used = host_parallel_chunks(n_rows, [&](int w, int r0, int r1) { + std::vector& lc = local[w]; + for (IndptrT p = h_indptr[r0]; p < h_indptr[r1]; p++) + lc[(size_t)h_indices[p]]++; + }); + for (int w = 0; w < used; w++) + for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; + } + + size_t cap = SAFE_BATCH_NNZ; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, chunk_cols, cap, [&](int c) { return h_col_counts[c]; }, + "Host CSR sparse-dense OVR offsets"); + chunk_cols = batches.sub_batch_cols; + size_t max_batch_nnz = batches.max_nnz; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + PinnedRing gather_stage(1, max_batch_nnz ? max_batch_nnz : 1); + PinnedRing indptr_stage(1, (size_t)n_rows + 1); + std::vector cursor(n_rows, 0); + std::vector row_counts(n_rows); + + int* d_group_codes = pool.alloc(n_rows); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int* d_all_offsets = upload_batch_offsets(batches, pool); + + InT* d_gather_vals = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* d_gather_cols = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* d_gather_indptr = pool.alloc(n_rows + 1); + int* col_offsets = pool.alloc(chunk_cols + 1); + int* write_pos = pool.alloc(chunk_cols); + InT* csc_vals_orig = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + float* csc_vals_f32 = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* csc_row_idx = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + float* dense = pool.alloc(max_dense_items); + double* sub_group_sums = + compute_stats ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(chunk_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(chunk_cols) + : nullptr; + + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + int col = 0; + for (int b = 0; b < batches.n_batches; b++) { + int sb_cols = std::min(chunk_cols, n_cols - col); + int col_end = col + sb_cols; + gather_stage.wait(0); + indptr_stage.wait(0); + InT* h_gather_vals = gather_stage.template get<0>(0); + int* h_gather_cols = gather_stage.template get<1>(0); + int* h_gather_indptr = indptr_stage.template get<0>(0); + + int batch_nnz = host_materialize_csr_column_interval_cursor( + h_data, h_indices, h_indptr, n_rows, col, col_end, cursor.data(), + row_counts.data(), h_gather_indptr, h_gather_vals, h_gather_cols, + "Host CSR sparse-dense gather nnz"); + + int* src = d_all_offsets + (size_t)b * (chunk_cols + 1); + cudaMemcpyAsync(col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + cudaMemcpyAsync(d_gather_vals, h_gather_vals, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_gather_cols, h_gather_cols, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + cudaMemcpyAsync(d_gather_indptr, h_gather_indptr, + (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + gather_stage.record(0, stream); + indptr_stage.record(0, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<<(n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE, 0, stream>>>( + d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, + csc_vals_orig, csc_row_idx, n_rows, col, col_end, 0); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, + UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>(col_offsets, csc_row_idx, + csc_vals_f32, dense, 0, + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + + if (compute_stats) { + scatter_cols_2d(d_group_sums + col, sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + } + + launch_ovr_dense_rank_window( + dense, d_group_codes, d_rank_sums, d_tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + col += sb_cols; + } +} + // ============================================================================ // Sparse-aware CSC OVR streaming (sort only stored nonzeros) // ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 55d28b63..e3a1032c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -5,6 +5,7 @@ #include "../nb_types.h" #include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon.cuh" #include "wilcoxon_sparse_kernels.cuh" #include "wilcoxon_ovr_kernels.cuh" #include "wilcoxon_ovr_sparse.cuh" @@ -131,6 +132,99 @@ void register_sparse_bindings(nb::module_& m) { int64_t); #undef RSC_OVR_SPARSE_CSR_HOST_BINDING +#define RSC_OVR_DENSE_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c group_codes, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int chunk_cols, \ + int rank_sub_batch_cols) { \ + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; \ + if (rank_sub_batch_cols <= 0) \ + rank_sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(data.data(), indices.data(), indptr.data(), \ + group_codes.data(), rank_sums.data(), tie_corr.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, chunk_cols, \ + rank_sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "group_codes"_a, "rank_sums"_a, \ + "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "chunk_cols"_a = SPARSE_DENSE_OVR_CHUNK_COLS, \ + "rank_sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_DEVICE_BINDING("ovr_dense_csc_device", + ovr_dense_csc_streaming_impl, int, int); + RSC_OVR_DENSE_DEVICE_BINDING( + "ovr_dense_csc_device", ovr_dense_csc_streaming_impl, int64_t, int64_t); + RSC_OVR_DENSE_DEVICE_BINDING("ovr_dense_csr_device", + ovr_dense_csr_streaming_impl, int, int); + RSC_OVR_DENSE_DEVICE_BINDING( + "ovr_dense_csr_device", ovr_dense_csr_streaming_impl, int64_t, int64_t); +#undef RSC_OVR_DENSE_DEVICE_BINDING + +#define RSC_OVR_DENSE_HOST_BINDING(NAME, IMPL, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_stats, \ + bool compute_nnz, bool compute_totals, int chunk_cols, \ + int rank_sub_batch_cols) { \ + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; \ + if (rank_sub_batch_cols <= 0) \ + rank_sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), d_rank_sums.data(), d_tie_corr.data(), \ + d_group_sums.data(), d_group_nnz.data(), d_total_sums.data(), \ + d_total_nnz.data(), n_rows, n_cols, n_groups, \ + compute_tie_corr, compute_stats, compute_nnz, compute_totals, \ + chunk_cols, rank_sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, "d_group_nnz"_a, \ + "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "compute_stats"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ + "chunk_cols"_a = SPARSE_DENSE_OVR_CHUNK_COLS, \ + "rank_sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, float, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, double, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, float, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, double, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, float, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, double, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, float, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, double, + int64_t, int64_t); +#undef RSC_OVR_DENSE_HOST_BINDING + #define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ m.def( \ NAME, \ diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 71a43f9f..0ef1f7eb 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -76,10 +76,12 @@ def rank_genes_groups( Rank genes for characterizing groups using GPU acceleration. Log1p/log-normalized data is expected for biologically meaningful log fold - changes. In-memory sparse inputs with explicit negative values fall back to - the dense full-sort ranking path; dense inputs are ranked directly and - support any sign. (``wilcoxon_binned`` rejects negative Dask sparse input, - which it cannot bin correctly.) + changes. In-memory sparse ``wilcoxon`` inputs with explicit negative values + use sign-safe dense ranking in the CUDA sparse streamers, materializing + bounded dense tiles inside the nanobind path. Dense inputs are ranked + directly and support any sign. + (``wilcoxon_binned`` rejects negative Dask sparse input, which it cannot + bin correctly.) .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 89a5dfde..e68e02c7 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -378,9 +378,8 @@ def compute_statistics( # as a tie at the column minimum (valid only for nonnegative data). # t-test/logreg are mean/variance/model-based and sign-agnostic. For the # Wilcoxon methods we canonicalize and, when sparse data holds - # negatives, fall back to the dense full-sort ranking (correct for any - # sign) rather than erroring -- so e.g. signed sparse data still ranks - # correctly, just via the dense path. + # negatives, route to sign-safe dense ranking inside the sparse + # streamers rather than erroring. self._sparse_negative_fallback = False if method in {"wilcoxon", "wilcoxon_binned"}: # Canonicalize before the negative check: summing duplicates can diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 32a59452..ecd93b10 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -20,8 +20,8 @@ def _sparse_has_negative(X) -> bool: The fast sparse Wilcoxon paths add implicit (structural) zeros as a tie at the column minimum, which is correct only for nonnegative stored values. A - negative breaks that, so the in-memory Wilcoxon paths fall back to the dense - full-sort path (valid for any sign). Dask arrays are not inspected here + negative breaks that, so in-memory Wilcoxon routes signed sparse data to + the sign-safe sparse-dense ranker. Dask arrays are not inspected here (they are neither ``scipy`` nor ``cupy`` sparse); ``wilcoxon_binned`` guards Dask sparse separately. Dense and t-test/logreg never need this. """ @@ -224,38 +224,3 @@ def _get_column_block(X, start: int, stop: int) -> cp.ndarray: return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") case _: raise ValueError(f"Unsupported matrix type: {type(X)}") - - -def _ovr_dense_block_f32(X, start: int, stop: int) -> cp.ndarray: - """OVR (vs-rest): ALL cells x gene-window, F-order float32. - - For sparse X (the negative-values dense fallback) the window is densified on - the fly via the shared CSR/CSC densify path (`_get_column_block`), so no - full-matrix dense materialization happens. - """ - if isinstance(X, np.ndarray | cp.ndarray): - return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") - if sp.issparse(X) or cpsp.issparse(X): - block = _get_column_block(X, start, stop) # float64 F-order chunk - return cp.asfortranarray(block.astype(cp.float32, copy=False)) - raise TypeError(f"Expected dense matrix, got {type(X)}") - - -def _ovo_dense_block(X, row_ids: np.ndarray, start: int, stop: int) -> cp.ndarray: - """OVO (with-reference): a ROW SUBSET (`row_ids`) x gene-window, F-order. - - OVO ranks the reference group against each other group, so it materializes - only the selected rows -- unlike `_ovr_dense_block_f32`, which takes all - cells. - """ - if isinstance(X, np.ndarray): - return cp.asarray(X[row_ids, start:stop], order="F") - if isinstance(X, cp.ndarray): - rows = cp.asarray(row_ids, dtype=cp.int32) - return cp.asfortranarray(X[rows, start:stop]) - if isinstance(X, sp.spmatrix | sp.sparray): - return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") - if cpsp.issparse(X): - rows = cp.asarray(row_ids, dtype=cp.int32) - return cp.asfortranarray(X[rows][:, start:stop].toarray()) - raise TypeError(f"Unsupported matrix type: {type(X)}") diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index eababe07..23f15ca0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -17,9 +17,6 @@ EPS, MIN_GROUP_SIZE_WARNING, _choose_chunk_size, - _get_column_block, - _ovo_dense_block, - _ovr_dense_block_f32, ) if TYPE_CHECKING: @@ -60,46 +57,6 @@ def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: return min(DEFAULT_WILCOXON_CHUNK_SIZE, max(1, n_genes)) -def _fill_ovo_chunk_stats( - rg: _RankGenes, - ref_block: cp.ndarray, - grp_block: cp.ndarray, - *, - offsets: np.ndarray, - test_group_indices: list[int], - start: int, - stop: int, - group_sizes: NDArray, -) -> None: - if not rg._compute_stats_in_chunks: - return - - ireference = rg.ireference - n_ref = int(group_sizes[ireference]) - ref_mean = ref_block.mean(axis=0) - rg.means[ireference, start:stop] = cp.asnumpy(ref_mean) - if n_ref > 1: - rg.vars[ireference, start:stop] = cp.asnumpy(ref_block.var(axis=0, ddof=1)) - if rg.comp_pts: - ref_nnz = (ref_block != 0).sum(axis=0) - rg.pts[ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) - - for slot, group_index in enumerate(test_group_indices): - begin = int(offsets[slot]) - end = int(offsets[slot + 1]) - n_group = int(group_sizes[group_index]) - group_block = grp_block[begin:end] - group_mean = group_block.mean(axis=0) - rg.means[group_index, start:stop] = cp.asnumpy(group_mean) - if n_group > 1: - rg.vars[group_index, start:stop] = cp.asnumpy( - group_block.var(axis=0, ddof=1) - ) - if rg.comp_pts: - group_nnz = (group_block != 0).sum(axis=0) - rg.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) - - def _fill_basic_stats_from_accumulators( rg: _RankGenes, group_sums: cp.ndarray, @@ -437,18 +394,17 @@ def _validate_wilcoxon_sparse_dtype(X) -> None: def _device_sparse_arrays(X): """Prepare device-sparse arrays for the Wilcoxon kernels. - Wilcoxon ranking sorts float32 keys on every path -- the sparse fast paths - AND the dense fallback (``_ovr_dense_block_f32``); the CUB segmented - sort is float-keyed throughout. Casting ``X.data`` to float32 here therefore - does not diverge from any float64 ranking path, because there is none. This - only loses precision when preprocessing ran in float64; float32-preprocessed - values (even if later stored as float64) are float32-exact, so ranking - matches scanpy bit-for-bit (~1e-13). For a fully float64 pipeline the - rank-derived scores/p-values match scanpy-on-float64 to ~1e-4 on - log-normalized data (below any significance threshold, no DE calls change), - while means and log fold changes are still computed in float64. See the - ``rank_genes_groups`` note on ranking precision. float64 input is accepted - to spare the caller a pre-cast. + Wilcoxon ranking sorts float32 keys on every sparse device path, including + the sign-safe sparse-dense OVR path. Casting ``X.data`` to float32 here + therefore does not diverge from any float64 ranking path, because there is + none. This only loses precision when preprocessing ran in float64; + float32-preprocessed values (even if later stored as float64) are + float32-exact, so ranking matches scanpy bit-for-bit (~1e-13). For a fully + float64 pipeline the rank-derived scores/p-values match scanpy-on-float64 + to ~1e-4 on log-normalized data (below any significance threshold, no DE + calls change), while means and log fold changes are still computed in + float64. See the ``rank_genes_groups`` note on ranking precision. float64 + input is accepted to spare the caller a pre-cast. """ data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float32: @@ -521,8 +477,8 @@ def wilcoxon( ) -def _host_sparse_format(X, *, sparse_negative_fallback: bool) -> str | None: - if sparse_negative_fallback or not isinstance(X, sp.spmatrix | sp.sparray): +def _host_sparse_format(X) -> str | None: + if not isinstance(X, sp.spmatrix | sp.sparray): return None if X.format not in {"csr", "csc"}: raise TypeError( @@ -532,9 +488,7 @@ def _host_sparse_format(X, *, sparse_negative_fallback: bool) -> str | None: return X.format -def _device_sparse_format(X, *, sparse_negative_fallback: bool) -> str | None: - if sparse_negative_fallback: - return None +def _device_sparse_format(X) -> str | None: if cpsp.isspmatrix_csc(X): return "csc" if cpsp.isspmatrix_csr(X): @@ -705,9 +659,7 @@ def _run_ovr_host_sparse( use_continuity: bool, return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]] | None: - sparse_format = _host_sparse_format( - X, sparse_negative_fallback=rg._sparse_negative_fallback - ) + sparse_format = _host_sparse_format(X) if sparse_format is None: return None @@ -815,9 +767,7 @@ def _run_ovr_device_sparse( use_continuity: bool, return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]] | None: - sparse_format = _device_sparse_format( - X, sparse_negative_fallback=rg._sparse_negative_fallback - ) + sparse_format = _device_sparse_format(X) if sparse_format is None: return None @@ -873,6 +823,124 @@ def _run_ovr_device_sparse( ) +def _run_ovr_signed_sparse_dense( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + host_format = ( + _host_sparse_format(X) if isinstance(X, sp.spmatrix | sp.sparray) else None + ) + device_format = _device_sparse_format(X) + sparse_format = host_format or device_format + if sparse_format is None: + return None + + n_groups = len(rg.groups_order) + group_codes_np = rg.group_codes.astype(np.int32, copy=False) + group_codes_gpu = cp.asarray(group_codes_np, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + chunk_cols = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + + if host_format is not None: + X.sort_indices() + compute_stats = rg._compute_stats_in_chunks + compute_nnz = compute_stats and rg.comp_pts + compute_totals = bool(compute_stats and np.any(group_codes_np == n_groups)) + stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + stats_shape if compute_nnz else (1, 1), + dtype=cp.float64, + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + runner = ( + _wcs.ovr_dense_csc_host if host_format == "csc" else _wcs.ovr_dense_csr_host + ) + runner( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes_np, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_stats=compute_stats, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + chunk_cols=chunk_cols, + rank_sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + if compute_stats: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes, + n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, + ) + else: + if isinstance(X, cpsp.spmatrix) and X.format == "csr": + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + runner = ( + _wcs.ovr_dense_csc_device + if device_format == "csc" + else _wcs.ovr_dense_csr_device + ) + runner( + data, + indices, + indptr, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + chunk_cols=chunk_cols, + rank_sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + + def _run_ovr_host_dense( rg: _RankGenes, X, @@ -952,7 +1020,7 @@ def _run_ovr_host_dense( ) -def _run_ovr_dense_chunks( +def _run_ovr_device_dense( rg: _RankGenes, X, n_cells: int, @@ -963,34 +1031,28 @@ def _run_ovr_dense_chunks( use_continuity: bool, chunk_size: int | None, return_u_values: bool, -) -> list[tuple[int, NDArray, NDArray]]: +) -> list[tuple[int, NDArray, NDArray]] | None: + if not isinstance(X, cp.ndarray): + return None + n_groups = len(rg.groups_order) chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - all_scores: dict[int, list] = {i: [] for i in range(n_groups)} - all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = ( + cp.empty(n_total_genes, dtype=cp.float64) + if tie_correct + else cp.ones(n_total_genes, dtype=cp.float64) + ) for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - if rg._compute_stats_in_chunks: - block = _get_column_block(X, start, stop) - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_codes_dev=group_codes_gpu, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, - ) - block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) - else: - block_f32 = _ovr_dense_block_f32(X, start, stop) - + block_f32 = cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") n_cols = stop - start - rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) - tie_corr = ( + sub_rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) + sub_tie_corr = ( cp.empty(n_cols, dtype=cp.float64) if tie_correct else cp.ones(n_cols, dtype=cp.float64) @@ -998,8 +1060,8 @@ def _run_ovr_dense_chunks( _wc.ovr_rank_dense_streaming( block_f32, group_codes_gpu, - rank_sums, - tie_corr, + sub_rank_sums, + sub_tie_corr, n_rows=n_cells, n_cols=n_cols, n_groups=n_groups, @@ -1007,26 +1069,20 @@ def _run_ovr_dense_chunks( sub_batch_cols=OVR_DENSE_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) - scores, p_values = _ovr_z_pvals( - rank_sums, - group_sizes_dev, - rest_sizes, - n_cells, - tie_corr, - use_continuity=use_continuity, - return_u_values=return_u_values, - ) - scores_host = scores.get() - p_host = p_values.get() - - for idx in range(n_groups): - all_scores[idx].append(scores_host[idx]) - all_pvals[idx].append(p_host[idx]) + rank_sums[:, start:stop] = sub_rank_sums + if tie_correct: + tie_corr[start:stop] = sub_tie_corr - return [ - (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) - for gi in range(n_groups) - ] + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) def _wilcoxon_vs_rest( @@ -1043,34 +1099,85 @@ def _wilcoxon_vs_rest( ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells.""" _warn_small_ovr_groups(rg, group_sizes, n_cells) - for runner in ( - _run_ovr_host_sparse, - _run_ovr_device_sparse, - _run_ovr_host_dense, - ): - result = runner( - rg, - X, - n_cells, - n_total_genes, - group_sizes, - tie_correct=tie_correct, - use_continuity=use_continuity, - return_u_values=return_u_values, - ) - if result is not None: - return result - return _run_ovr_dense_chunks( - rg, - X, - n_cells, - n_total_genes, - group_sizes, - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - return_u_values=return_u_values, - ) + match X: + case sp.spmatrix() | sp.sparray(): + if rg._sparse_negative_fallback: + result = _run_ovr_signed_sparse_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + else: + result = _run_ovr_host_sparse( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case _ if _device_sparse_format(X) is not None: + if rg._sparse_negative_fallback: + result = _run_ovr_signed_sparse_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + else: + result = _run_ovr_device_sparse( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case np.ndarray(): + result = _run_ovr_host_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case cp.ndarray(): + result = _run_ovr_device_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case _: + msg = f"Unsupported Wilcoxon OVR input type: {type(X)}" + raise TypeError(msg) + if result is not None: + return result + msg = f"Unsupported Wilcoxon OVR input type: {type(X)}" + raise TypeError(msg) def _run_ovo_host_sparse( @@ -1084,9 +1191,7 @@ def _run_ovo_host_sparse( use_continuity: bool, return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]] | None: - sparse_format = _host_sparse_format( - X, sparse_negative_fallback=rg._sparse_negative_fallback - ) + sparse_format = _host_sparse_format(X) if sparse_format is None: return None @@ -1192,9 +1297,7 @@ def _run_ovo_device_sparse( use_continuity: bool, return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]] | None: - sparse_format = _device_sparse_format( - X, sparse_negative_fallback=rg._sparse_negative_fallback - ) + sparse_format = _device_sparse_format(X) if sparse_format is None: return None @@ -1327,49 +1430,39 @@ def _run_ovo_host_dense( ) -def _run_ovo_dense_chunks( +def _run_ovo_device_dense( rg: _RankGenes, X, ctx: _OvoContext, n_total_genes: int, - group_sizes: NDArray, + _group_sizes: NDArray, *, tie_correct: bool, use_continuity: bool, chunk_size: int | None, return_u_values: bool, -) -> list[tuple[int, NDArray, NDArray]]: - chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) - scores_host = np.empty((ctx.n_test, n_total_genes), dtype=np.float64) - pvals_host = np.empty((ctx.n_test, n_total_genes), dtype=np.float64) +) -> list[tuple[int, NDArray, NDArray]] | None: + if not isinstance(X, cp.ndarray): + return None + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + ref_rows = cp.asarray(ctx.ref_row_ids, dtype=cp.int32) + grp_rows = cp.asarray(ctx.all_grp_row_ids, dtype=cp.int32) + rank_sums = cp.empty((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) n_cols = stop - start - ref_block = _ovo_dense_block(X, ctx.ref_row_ids, start, stop) - grp_block = _ovo_dense_block(X, ctx.all_grp_row_ids, start, stop) - - _fill_ovo_chunk_stats( - rg, - ref_block, - grp_block, - offsets=ctx.offsets_np, - test_group_indices=ctx.test_group_indices, - start=start, - stop=stop, - group_sizes=group_sizes, - ) - - ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") - grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") - rank_sums = cp.zeros((ctx.n_test, n_cols), dtype=cp.float64) - tie_corr = cp.ones((ctx.n_test, n_cols), dtype=cp.float64) + ref_f32 = cp.asarray(X[ref_rows, start:stop], dtype=cp.float32, order="F") + grp_f32 = cp.asarray(X[grp_rows, start:stop], dtype=cp.float32, order="F") + sub_rank_sums = cp.empty((ctx.n_test, n_cols), dtype=cp.float64) + sub_tie_corr = cp.ones((ctx.n_test, n_cols), dtype=cp.float64) _wc.ovo_rank_dense_tiered_unsorted_ref( ref_f32, grp_f32, ctx.offsets_gpu, - rank_sums, - tie_corr, + sub_rank_sums, + sub_tie_corr, n_ref=ctx.n_ref, n_all_grp=ctx.n_all_grp, n_cols=n_cols, @@ -1378,22 +1471,22 @@ def _run_ovo_dense_chunks( sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) - scores, p_values = _ovo_z_pvals( - rank_sums, - ctx.test_sizes, - ctx.n_ref, - tie_corr, - tie_correct=tie_correct, - use_continuity=use_continuity, - return_u_values=return_u_values, - ) - scores_host[:, start:stop] = scores.get() - pvals_host[:, start:stop] = p_values.get() + rank_sums[:, start:stop] = sub_rank_sums + if tie_correct: + tie_corr_arr[:, start:stop] = sub_tie_corr - return [ - (group_index, scores_host[slot], pvals_host[slot]) - for slot, group_index in enumerate(ctx.test_group_indices) - ] + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=None, + ) def _wilcoxon_with_reference( @@ -1412,32 +1505,57 @@ def _wilcoxon_with_reference( if ctx.n_test == 0: return [] _warn_small_ovo_groups(rg, ctx, group_sizes) - for runner, extra in ( - (_run_ovo_host_sparse, {}), - (_run_ovo_device_sparse, {}), - (_run_ovo_host_dense, {"chunk_size": chunk_size}), - ): - result = runner( - rg, - X, - ctx, - n_total_genes, - group_sizes, - tie_correct=tie_correct, - use_continuity=use_continuity, - return_u_values=return_u_values, - **extra, - ) - if result is not None: - return result - return _run_ovo_dense_chunks( - rg, - X, - ctx, - n_total_genes, - group_sizes, - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - return_u_values=return_u_values, - ) + match X: + case sp.spmatrix() | sp.sparray(): + result = _run_ovo_host_sparse( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case _ if _device_sparse_format(X) is not None: + result = _run_ovo_device_sparse( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case np.ndarray(): + result = _run_ovo_host_dense( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case cp.ndarray(): + result = _run_ovo_device_dense( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case _: + msg = f"Unsupported Wilcoxon OVO input type: {type(X)}" + raise TypeError(msg) + if result is not None: + return result + msg = f"Unsupported Wilcoxon OVO input type: {type(X)}" + raise TypeError(msg) From cff61e3922da031f08779e6a1be68dc953e9e31c Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 26 Jun 2026 13:52:25 +0200 Subject: [PATCH 36/36] slim down comments Signed-off-by: Intron7 --- .github/workflows/docker.yml | 5 +- .github/workflows/publish.yml | 17 +- CMakeLists.txt | 24 +-- pyproject.toml | 19 +- src/rapids_singlecell/_cuda/__init__.py | 10 +- src/rapids_singlecell/_cuda/nb_types.h | 37 +--- .../_cuda/rank_genes/rank_stats.cu | 11 +- src/rapids_singlecell/_cuda/rmm_scratch.cu | 25 +-- src/rapids_singlecell/_cuda/rmm_scratch.h | 8 +- .../_cuda/sparse_extract/sparse_extract.cuh | 57 ++---- .../_cuda/streaming/streaming.cuh | 60 +++---- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 9 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 59 +++---- .../_cuda/wilcoxon/wilcoxon.cu | 8 +- .../_cuda/wilcoxon/wilcoxon_block_reduce.cuh | 5 +- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 44 ++--- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 27 ++- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 48 ++---- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 41 ++--- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 43 ++--- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 99 +++-------- .../_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh | 8 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 50 ++---- .../tools/_rank_genes_groups/_core.py | 28 +-- .../tools/_rank_genes_groups/_logreg.py | 8 +- .../tools/_rank_genes_groups/_utils.py | 68 ++------ .../tools/_rank_genes_groups/_wilcoxon.py | 21 +-- .../_rank_genes_groups/_wilcoxon_binned.py | 87 ++-------- tests/test_rank_genes_groups_ttest.py | 13 +- tests/test_rank_genes_groups_wilcoxon.py | 163 ++++-------------- .../test_rank_genes_groups_wilcoxon_binned.py | 23 +-- 31 files changed, 290 insertions(+), 835 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 7cb46339..4f824ec6 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,6 +1,5 @@ -# This workflow will build two Docker image and push then to GitHub Packages Container registry: -# - a base image with the dependencies -# - a main image with the application code +# Build/push two GHCR images: dependency base and application image. +# Release events push; PR/comment runs only validate. name: Docker diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index efb3f656..9b086988 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -94,9 +94,8 @@ jobs: text = text.replace(f'rapids-cu{cuda} = [', 'rapids = [') text = remove_toml_array(text, f"rapids-cu{other}") - # librmm is needed at build time because CMake links the CUDA - # extension against librmm. Add the matching wheel to the isolated - # PEP 517 build requirements after selecting the CUDA package variant. + # CMake links CUDA extensions against librmm. + # Add the matching wheel to isolated build requirements. for dep in ( f' "librmm-cu{other}>=25.12",\n', f' "rmm-cu{other}>=25.12",\n', @@ -164,16 +163,8 @@ jobs: echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - # Exclude CUDA libs by SONAME glob (auditwheel >=6.2): the runtime - # stack (CuPy / nvidia-* wheels) provides them. Globs are version - # agnostic -- cusolver's SONAME is libcusolver.so.11 on CUDA 12 but - # .12 on CUDA 13, and nvJitLink is .12 vs .13, so pinning to the CUDA - # major would graft the wrong (or no) lib. cusolver's transitive deps - # (cublasLt, cusparse ~186MB, nvJitLink) are reached by auditwheel's - # tree walk and must each be excluded or they bloat the wheel. - # librmm.so / librapids_logger.so are also excluded: they are NOT in - # the CuPy/nvidia stack but are provided by the librmm / rapids_logger - # wheels at runtime, so we must not bundle them into our wheel. + # Exclude CUDA/RAPIDS runtime libs provided by dependency wheels. + # Use SONAME globs so CUDA 12/13 suffix changes do not bundle them. CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" diff --git a/CMakeLists.txt b/CMakeLists.txt index 935945c3..63fb3f01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,10 +52,8 @@ if (RSC_BUILD_EXTENSIONS) else() set(_rsc_python_rmm_hint "") endif() - # Wheel builds install librmm/rapids_logger into the isolated build env and - # write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks - # those shared libraries into /usr/local/lib so auditwheel can see and exclude - # them instead of bundling RAPIDS runtime libraries into the wheel. + # Wheel builds write build/.librmm_dir from CIBW_BEFORE_BUILD. + # publish.yml symlinks runtime libs so auditwheel excludes them. if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake") set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}") elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir") @@ -120,16 +118,8 @@ if (RSC_BUILD_EXTENSIONS) find_package(rmm CONFIG REQUIRED) endif() - # CCCL 3.3.0 (shipped by RAPIDS 26.04) declares the - # cudaDevAttrHostNumaMemoryPoolsSupported device-attribute specialization under - # `_CCCL_CTK_AT_LEAST(12, 6)`, but the CUDA runtime only added that enum in 12.9. - # So compiling the RMM/CCCL-including TUs (the Wilcoxon scratch allocator) against - # a CUDA 12.6-12.8 toolkit fails with a cryptic - # `error: the global scope has no "cudaDevAttrHostNumaMemoryPoolsSupported"`. - # CCCL fixed the guard to `_CCCL_CTK_AT_LEAST(12, 9)` after 3.3.0 (cccl PR #7838), - # so RAPIDS >= 26.06 (CCCL > 3.3.0) closes the gap -- only flag the buggy CCCL. - # Fail fast with an actionable message. Prebuilt wheels are unaffected: they are - # built on CUDA 12.2 (below the guard), so the enum is never referenced. + # CCCL 3.3.0 gates cudaDevAttrHostNumaMemoryPoolsSupported too loosely. + # Fail fast for CUDA 12.6-12.8 source builds with that buggy CCCL. set(_rsc_cccl_buggy_numa_guard TRUE) if (DEFINED CCCL_VERSION AND CCCL_VERSION VERSION_GREATER 3.3.0) set(_rsc_cccl_buggy_numa_guard FALSE) @@ -196,10 +186,8 @@ function(add_nb_cuda_module target src) endif() endfunction() -# An RMM-backed nanobind CUDA module: add_nb_cuda_module plus the shared RMM -# scratch allocator (rmm_scratch.cu) and the rmm::rmm link. Installed wheels -# resolve RAPIDS runtime libs from sibling Python packages; editable source-tree -# imports still have the _cuda/__init__.py preload fallback. +# RMM-backed nanobind CUDA module: normal module plus shared scratch allocator. +# Wheels use sibling RAPIDS packages; editable imports still preload fallbacks. function(add_rmm_cuda_module target src) add_nb_cuda_module(${target} ${src}) if (RSC_BUILD_EXTENSIONS) diff --git a/pyproject.toml b/pyproject.toml index 9b9c87fa..56611f61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,14 +3,8 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", - # librmm headers/CMake config are needed at build time for Wilcoxon. - # Generic isolated source builds default to CUDA 12. CUDA wheel builds - # rewrite this to the matching cu12/cu13 package; CUDA 13 source builds - # should build in an existing RAPIDS env with --no-build-isolation. - # 25.12 floor: the Wilcoxon scratch allocator uses the resource-ref API - # (get_current_device_resource_ref().allocate_sync), and the flat - # header path only exists from 25.12 - # onward (25.10 kept it under rmm/mr/device/). Builds through RMM 26.06+. + # Wilcoxon links librmm at build time; generic isolated builds use CUDA 12. + # 25.12+ provides the resource-ref API and flat RMM header path we use. "librmm-cu12>=25.12", ] build-backend = "scikit_build_core.build" @@ -173,13 +167,8 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] -# Exclude CUDA libs by SONAME glob (auditwheel >= 6.2): suffixes are version- -# dependent (cusolver is libcusolver.so.11 on CUDA 12 but .12 on CUDA 13), so a -# `.*` glob stays correct across CUDA majors where hardcoded `.12`/`.13` would -# miss variants and bundle ~186MB (cublasLt, cusparse, nvJitLink) that CuPy -# provides at runtime. librmm.so / librapids_logger.so come from the librmm / -# rapids_logger wheels. Keep this list in sync with CIBW_REPAIR_WHEEL_COMMAND in -# .github/workflows/publish.yml (that env var overrides this block in CI). +# Exclude CUDA/RAPIDS runtime libs provided by dependency wheels. +# Keep in sync with CIBW_REPAIR_WHEEL_COMMAND in publish.yml. repair-wheel-command = [ "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 2b93a142..d4f70d12 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -56,10 +56,7 @@ def _preload_rapids_runtime_libs() -> None: - """Pre-load ``librmm`` / ``rapids_logger`` so the extensions' ``DT_NEEDED`` - soname deps resolve regardless of import order (the editable-install - ``RUNPATH`` is unreliable). Best-effort: absent wheels (docs builds) skip. - """ + """Pre-load RAPIDS runtime libs so extension ``DT_NEEDED`` deps resolve.""" for mod in ("librmm", "rapids_logger"): try: importlib.import_module(mod).load_library() @@ -78,9 +75,8 @@ def __getattr__(name: str): # Extension genuinely absent (docs/no-GPU): degrade to None. return None except ImportError as exc: - # Present but failed to load (ABI/toolkit mismatch, missing .so, rmm - # symbol-ordering): surface with context, don't return None and crash - # later with a cryptic ``'NoneType' has no attribute ...``. + # Present but failed to load: surface ABI/toolkit/lib errors now. + # Returning None would cause a later cryptic attribute error. msg = ( f"Failed to load compiled CUDA extension {name!r}: {exc}. " "Ensure a matching rapids-singlecell-cuXX wheel (and librmm) is " diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 855d36a9..dc27d4f1 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -20,9 +20,8 @@ inline void cuda_check_last_error(const char* kernel_name) { #define CUDA_CHECK_LAST_ERROR(kernel_name) cuda_check_last_error(#kernel_name) -/// Check a cudaError_t returned directly by a CUDA/CUB API call (vs. -/// CUDA_CHECK_LAST_ERROR which inspects state after a launch), so a failed -/// call surfaces with a clear label instead of as corrupted output later. +/// Check a cudaError_t returned directly by a CUDA/CUB API call. +/// Failed calls surface with a clear label instead of corrupted output later. inline void cuda_check(cudaError_t err, const char* what) { if (err != cudaSuccess) { throw std::runtime_error(std::string(what) + @@ -31,8 +30,7 @@ inline void cuda_check(cudaError_t err, const char* what) { } /// Validate a binding-argument precondition (array dims vs. scalar shapes). -/// Throws std::invalid_argument so a mismatch is a clean Python error instead -/// of an out-of-bounds kernel launch. +/// Mismatches become clean Python errors, not out-of-bounds launches. inline void nb_require(bool cond, const char* what) { if (!cond) { throw std::invalid_argument( @@ -40,13 +38,8 @@ inline void nb_require(bool cond, const char* what) { } } -/// Per-axis cached cap on `gridDim.{x,y,z}`. These differ in CUDA: -/// gridDim.x: 2^31-1 on CC 3.0+ -/// gridDim.y: 65535 on most GPUs -/// gridDim.z: 65535 -/// Newer hardware may relax these; we read at runtime and cache per device. -/// Returns a 3-element array indexed by 0=x, 1=y, 2=z. Multi-GPU safe via -/// thread-local cache keyed on the active device. +/// Per-axis cached cap on `gridDim.{x,y,z}`; y/z are often only 65535. +/// Runtime per-device cache keeps this multi-GPU safe. inline const int* max_grid_dims() { static thread_local int cached_dev = -1; static thread_local int cached[3] = {65535, 65535, 65535}; // safe fallback @@ -73,15 +66,8 @@ inline int max_grid_dim_z() { return max_grid_dims()[2]; } -/// Grid-stride cap for kernels whose total work `nwork` (e.g. nnz, n_cells * -/// n_genes) may exceed what a single grid launch can cover. Pair with a -/// grid-strided loop inside the kernel: -/// -/// const long long stride = (long long)blockDim.x * gridDim.x; -/// for (long long i = ...; i < nwork; i += stride) { ... } -/// -/// Defaults to the `gridDim.x` cap. For 2D launches whose strided axis is y, -/// use `strided_grid_y`. Returns at least 1. +/// Grid-stride cap for kernels whose total work exceeds one grid launch. +/// Pair with a grid-strided loop; use `strided_grid_y` for y-axis launches. inline unsigned int strided_grid(long long nwork, int block_size) { const long long max_grid = max_grid_dim_x(); long long ideal = (nwork + block_size - 1) / block_size; @@ -98,9 +84,7 @@ inline unsigned int strided_grid_y(long long nwork, int block_size) { } // GPU array aliases for nanobind bindings, parameterized on device type. -// Bindings are registered for both nb::device::cuda (kDLCUDA = 2) and -// nb::device::cuda_managed (kDLCUDAManaged = 13) so that RMM managed-memory -// allocations are accepted without losing type safety for CPU arrays. +// CUDA and managed-memory variants both preserve CPU/GPU type safety. // C-contiguous (row-major) template @@ -127,10 +111,7 @@ template using host_array_f2 = nb::ndarray, nb::f_contig>; // Register bindings for both regular CUDA and managed-memory arrays. -// Usage: -// template -// void register_bindings(nb::module_& m) { ... } -// NB_MODULE(_foo_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); } +// Each registration function must be templated on `Device`. #define REGISTER_GPU_BINDINGS(func, module) \ func(module); \ func(module) diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu index 6c07afcb..6064e1bf 100644 --- a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -9,9 +9,8 @@ namespace { constexpr int GROUP_STATS_BLOCK = 256; -// Benjamini-Hochberg step-up tail: in-place reverse cumulative minimum per row -// of an already BH-scaled, p-value-sorted matrix. NaNs treated as 1.0. One -// block per row, single thread (serial scan). +// Benjamini-Hochberg tail: reverse cumulative min on sorted, BH-scaled rows. +// NaNs become 1.0; one serial thread per row. __global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { const int row = blockIdx.x; double running = 1.0; @@ -28,10 +27,8 @@ __global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { } } -// Per-group sum / sum-of-squares / nnz over a dense F-order block. group_codes -// maps each row to a group; out-of-range codes are skipped. C-order -// (n_groups x n_cols) outputs accumulated with atomics. Grid-strided so chunks -// larger than the gridDim.x cap are still fully covered. +// Per-group sum/sumsq/nnz over a dense F-order block; invalid groups are +// skipped. Outputs are C-order group x col and grid-strided beyond gridDim.x. __global__ void group_chunk_stats_kernel( const double* block, const int* group_codes, double* group_sums, double* group_sum_sq, double* group_nnz, const int n_rows, const int n_cols, diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.cu b/src/rapids_singlecell/_cuda/rmm_scratch.cu index ea519d2d..474e6227 100644 --- a/src/rapids_singlecell/_cuda/rmm_scratch.cu +++ b/src/rapids_singlecell/_cuda/rmm_scratch.cu @@ -7,12 +7,8 @@ #include "rmm_scratch.h" -// Use the resource-ref API (`get_current_device_resource_ref()` + -// value-semantic `allocate_sync`/`deallocate_sync`) rather than the raw-pointer -// `get_current_device_resource()`. RMM 26.06 removes both the raw-pointer -// accessor and `` as it migrates to the cccl -// `cuda::mr` resource-concept model. The ref form compiles unchanged from -// RMM 25.12 through 26.06 (and onward), so it covers 26.04+. +// Use the RMM resource-ref API; RMM 26.06 removed the raw-pointer accessor. +// The ref form compiles unchanged from RMM 25.12 through 26.06+. void* rmm_allocate(size_t bytes) { try { return rmm::mr::get_current_device_resource_ref().allocate_sync(bytes); @@ -27,21 +23,8 @@ void rmm_deallocate(void* ptr, size_t bytes) { rmm::mr::get_current_device_resource_ref().deallocate_sync(ptr, bytes); } -// `fraction` * the free device memory reported by cudaMemGetInfo. -// -// Deliberately a plain query, NOT a trial-allocation probe. Probing a pool's -// internal free by allocating until it grows permanently RATCHETS the pool -// (RMM pools never shrink): repeated wilcoxon calls would grow it toward the -// whole device and then starve non-pool allocations like cudaStreamCreate -// ("out of memory" on stream creation). cudaMemGetInfo free is correct and -// safe everywhere: -// * Plain cuda: exact. -// * Pool: the memory OUTSIDE the pool's reservation; the pool also serves -// from its internal free, so this is conservative but never over-budgets -// and never grows the pool. The host-streaming paths transfer each nonzero -// once regardless of batch size (per-row cursor gather), so a smaller -// budget only adds a few more passes -- it does not re-stream. -// * Managed/UVM: device-resident free, so sizing to it avoids host spill. +// Plain cudaMemGetInfo budget query, never a pool-probing trial allocation. +// Probing ratchets RMM pools and can starve non-pool allocations like streams. size_t rmm_available_device_bytes(double fraction) { if (fraction <= 0.0) return 0; size_t free_b = 0, total_b = 0; diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.h b/src/rapids_singlecell/_cuda/rmm_scratch.h index 236c94dd..4b41fe15 100644 --- a/src/rapids_singlecell/_cuda/rmm_scratch.h +++ b/src/rapids_singlecell/_cuda/rmm_scratch.h @@ -11,12 +11,8 @@ void* rmm_allocate(size_t bytes); void rmm_deallocate(void* ptr, size_t bytes); -// fraction * cudaMemGetInfo free. A plain query, never a trial-allocation probe -// (probing a pool's internal free ratchets the pool to the whole device and -// starves non-pool allocations like cudaStreamCreate). Conservative under a -// pool but safe across the default cuda resource, a pool, and managed/UVM; the -// host-streaming paths transfer each nonzero once regardless of batch size, so -// a smaller budget only adds passes. Use for every GPU-memory-budget decision. +// fraction * cudaMemGetInfo free; never trial-probe a pool. +// Probing ratchets RMM pools and can starve cudaStreamCreate. size_t rmm_available_device_bytes(double fraction); // Allocation pool for temporary CUDA buffers; frees everything on scope exit. diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh index 3c01a5de..81a2c519 100644 --- a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -2,24 +2,11 @@ #include -// ============================================================================ -// Shared CSR/CSC -> {compact CSC, dense} extraction kernels (header-only). -// * compact CSC (csr_scatter_to_csc) -> sparse ranker (nnz only) -// * dense F-order (csr_tile_to_dense, extract) -> dense ranker (all values) -// ============================================================================ - -/** - * Scatter CSR nonzeros into compact CSC for columns [col_start, col_stop). - * write_pos[c - col_start] is column c's prefix-sum offset; threads atomically - * claim destination slots. - * - * PRECONDITION: each row's `indices` sorted ascending -- the binary search for - * col_start and the `break` at col_stop depend on it; unsorted rows would - * silently drop/misplace nonzeros. Python dispatch calls sort_indices() first. - * - * `row_offset` rebases a local-row block to its global row id (out-of-core - * row-streaming OVR path). 0 for full-matrix callers. - */ +// Shared CSR/CSC extraction kernels for compact CSC and dense F-order tiles. +// Callers canonicalize/sort before kernels that binary-search row indices. + +// Scatter CSR nonzeros into compact CSC for columns [col_start, col_stop). +// `row_offset` rebases local row blocks; write_pos is atomically claimed. template __global__ void csr_scatter_to_csc_kernel( const InT* __restrict__ data, const IndexT* __restrict__ indices, @@ -48,12 +35,8 @@ __global__ void csr_scatter_to_csc_kernel( } } -// Single-pass CSR-slice + densify: scatter column window [col_lb, col_ub) into -// a dense (n_cells, col_ub-col_lb) F-order double buffer. -// -// `out` must be pre-zeroed; atomicAdd sums duplicate column indices (like -// scipy's sum_duplicates) -- bit-identical to dense materialization for -// canonical CSR. Output always double; input dtype templated. +// CSR column window [col_lb, col_ub) -> pre-zeroed dense F-order tile. +// atomicAdd preserves summed duplicate semantics for canonicalized CSR. template __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, @@ -81,13 +64,8 @@ __global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, } } -// CSC column-window [col_lb, col_ub) -> dense F-order (double), one block per -// column. NO atomicAdd -- canonical CSC has a unique (col,row) per nonzero (the -// wilcoxon dispatch canonicalizes/sums first). CSC counterpart to -// csr_tile_to_dense_kernel. -// -// `out` must be pre-zeroed. `indptr` indexes columns; pass full-matrix column -// pointers (with col_lb/col_ub) or a window rebased to [0, col_ub-col_lb). +// CSC column window [col_lb, col_ub) -> pre-zeroed dense F-order tile. +// No atomics: canonical CSC has one stored value per (col, row). template __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, @@ -106,9 +84,8 @@ __global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, } } -// CSR selected rows -> dense F-order. row_ids[tid] = source row; output column -// is (col - col_start), output row is tid. Requires sorted indices (binary -// search + break). Output must be pre-zeroed. +// CSR selected rows -> pre-zeroed dense F-order tile. +// Requires sorted row indices for binary-search + col_stop break. template __global__ void csr_extract_dense_kernel(const T* __restrict__ data, const IndexT* __restrict__ indices, @@ -160,11 +137,8 @@ __global__ void csr_extract_dense_identity_rows_unsorted_kernel( } } -/** - * Extract rows from CSC into dense F-order via a row lookup map. - * row_map[original_row] = output_row_index (or -1 to skip). - * One block per column. Output must be pre-zeroed. - */ +// CSC selected rows -> pre-zeroed dense F-order tile. +// row_map[original_row] gives output row, or -1 to skip. template __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, const IndexT* __restrict__ indices, @@ -186,9 +160,8 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } } -// Narrowing element-wise cast (e.g. int64 row indices -> int32 sort values). -// Used only when the input index width exceeds int32; the caller guarantees the -// values fit the destination type (row/col positions < 2^31). +// Narrowing element-wise cast, used only when input index width exceeds int32. +// Caller guarantees row/column positions fit the destination type. template __global__ void cast_array_kernel(const SrcT* __restrict__ src, DstT* __restrict__ dst, size_t n) { diff --git a/src/rapids_singlecell/_cuda/streaming/streaming.cuh b/src/rapids_singlecell/_cuda/streaming/streaming.cuh index fe9fcebf..ebb9f4c9 100644 --- a/src/rapids_singlecell/_cuda/streaming/streaming.cuh +++ b/src/rapids_singlecell/_cuda/streaming/streaming.cuh @@ -31,9 +31,8 @@ static inline int host_worker_count() { return (int)std::min(hw ? hw : 4u, 32u); } -// Run fn(chunk, r0, r1) over a partition of [0, n); `chunk` = 0-based worker -// index. fn runs concurrently: read-only shared state, disjoint output ranges -// (keyed by chunk or [r0,r1)). Returns chunks used; serial for small n. +// Run fn(chunk, r0, r1) over partitions of [0, n), serial for small n. +// Concurrent callers must use read-only shared state and disjoint outputs. template static inline int host_parallel_chunks(int n, F fn) { if (n <= 0) return 0; @@ -157,9 +156,8 @@ static inline int clamp_streams_by_budget(int n_streams, return n_streams; } -// Scatter a [rows, sb_cols] device sub-batch (row-major doubles, src stride -// sb_cols) into `dst` (stride n_cols). `dst` must point at the dest column -// offset (e.g. out + col). +// Scatter row-major [rows, sb_cols] into destination stride n_cols. +// `dst` must already point at the destination column offset. static inline void scatter_cols_2d(double* dst, const double* src, int rows, int n_cols, int sb_cols, cudaStream_t stream) { @@ -168,10 +166,8 @@ static inline void scatter_cols_2d(double* dst, const double* src, int rows, cudaMemcpyDeviceToDevice, stream); } -// Halve sub_batch_cols until the densest window holds <= cap nonzeros, keeping -// every batch's nnz within int32 for CUB and bounding per-stream transpose/sort -// scratch. col_nnz(i) = nnz of column i. Worst case returns 1 (single column, -// nnz <= n_rows). +// Halve sub_batch_cols until the densest window holds <= cap nonzeros. +// Keeps CUB item counts and per-stream scratch bounded; worst case returns 1. template static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, size_t cap, ColNnz col_nnz) { @@ -359,10 +355,8 @@ struct HostRegisterGuard { if (p && bytes > 0) { cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { - // Already-registered = owned elsewhere; use it without - // unregistering. Other failures make mapped reads unsafe, so - // surface them -- unless best_effort (pin is only a speedup; - // unpinned H2D still works). + // Already-registered memory is owned elsewhere; use as-is. + // Other failures are fatal unless pinning is only a speedup. if (err == cudaErrorHostMemoryAlreadyRegistered || best_effort) { cudaGetLastError(); // clear sticky error flag @@ -396,11 +390,8 @@ struct HostRegisterGuard { } }; -// RAII for CUDA streams/events: reclaim on every path (incl. exception unwind). -// Stream dtor SYNCHRONIZES before destroying. CRITICAL ordering: declare the -// RmmScratchPool BEFORE these guards so streams (destroyed first) drain -// in-flight kernels before the pool (destroyed last) frees the scratch they -// read. +// RAII for CUDA streams/events: stream destruction synchronizes first. +// Declare RmmScratchPool before guards so streams drain before scratch frees. struct ScopedCudaStream { cudaStream_t stream = nullptr; @@ -517,9 +508,8 @@ struct PinnedRingArray { } }; -// Per-slot pinned host staging with events, so CPU materialization into one -// slot can overlap GPU consumption of another. All arrays have the same item -// capacity; use a second ring for differently-sized metadata such as indptr. +// Per-slot pinned host staging with events for CPU/GPU overlap. +// Arrays share item capacity; use another ring for differently-sized metadata. template struct PinnedRing { std::tuple...> arrays; @@ -583,9 +573,8 @@ __global__ void fill_linear_offsets_kernel(int* __restrict__ out, if (i <= n_segments) out[i] = i * stride; } -/** Rebase a slice of indptr: out[i] = indptr[col+i] - indptr[col]. Grid-strided - * (arbitrary `count`). Templated so 64-bit global indptrs produce 32-bit - * pack-local indptrs (per-pack nnz fits int32 via the memory budget). */ +/** Rebase indptr slice to a local origin, grid-strided for arbitrary count. + * 64-bit global indptrs may produce 32-bit pack-local indptrs. */ template __global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, IdxOut* __restrict__ out, int col, @@ -594,10 +583,8 @@ __global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); } -// Threaded host gather of selected rows into compact staging (f32 vals + int32 -// cols) at disjoint per-row offsets (compact_indptr - base) -> race-free. -// No-pin alternative to the mapped gather kernel: only the compacted slice -// crosses the bus. +// Threaded selected-row gather into compact staging at disjoint offsets. +// No-pin alternative: only the compacted slice crosses the bus. template static void host_gather_rows_compact_as( @@ -630,9 +617,8 @@ static void host_gather_rows_compact(const InT* h_data, const IndexT* h_indices, n_target, stage_vals, stage_cols); } -// Threaded host cast-copy of a contiguous nnz slice into staging (f32 + int32). -// CSC analogue of host_gather_rows_compact: contiguous column batch, no gather. -// nnz fits int32 (batch-bounded). +// Threaded host cast-copy of a contiguous nnz slice into staging. +// CSC analogue of row gather: contiguous column batch, bounded int32 nnz. template static void host_copy_slice_as(const InT* h_data, const IndexT* h_indices, @@ -662,9 +648,8 @@ static void host_cast_copy_slice(const InT* h_data, const IndexT* h_indices, stage_cols); } -// Threaded host gather of selected dense rows and a contiguous column window -// into an F-order staging tile. f_order describes the full source matrix; the -// staged output is always [n_window_rows, n_window_cols] in column-major order. +// Threaded host gather of selected dense rows and contiguous columns. +// Output staging is always F-order [n_window_rows, n_window_cols]. template static void host_materialize_dense_rows_window_as( const InT* h_X, bool f_order, int n_full_rows, int n_full_cols, @@ -764,9 +749,8 @@ static void host_materialize_csr_cols_window( [&](int col) { return col_map[col]; }, stage_vals, stage_cols); } -// Optimized CSR -> contiguous-column-window materialization for sorted rows and -// ascending column batches. The per-row cursor means each nonzero is examined -// once across the full stream. +// Optimized CSR -> contiguous-column-window materialization for sorted rows. +// The per-row cursor examines each nonzero once across the full stream. template static int host_materialize_csr_column_interval_cursor_as( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index db5f492c..2e963df8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -5,12 +5,9 @@ #include "wilcoxon_block_reduce.cuh" #include "wilcoxon_ovr_tie_walk.cuh" -// Dense OVR rank kernel. sorted_vals/sorted_row_idx are F-order arrays from a -// segmented SortPairs. One block per column; walks sorted tie runs and -// accumulates average ranks per group without materializing a rank matrix. -// The `use_gmem` flag (set by ovr_smem_config) selects shared- vs -// global-memory group accumulators -- CRITICAL: the use_gmem path is REQUIRED -// when n_groups is large (does NOT fit in smem) and must not be removed. +// Dense OVR rank kernel over sorted F-order columns; no rank matrix +// materialized. CRITICAL: `use_gmem` is required when n_groups exceeds +// shared-memory capacity. __global__ void rank_sums_from_sorted_kernel( const float* __restrict__ sorted_vals, const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index ff6a2f1a..98ea3f06 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -5,9 +5,8 @@ #include "wilcoxon_block_reduce.cuh" #include "wilcoxon_fast_common.cuh" -// Bitonic sort of `n` floats in shared memory, ascending. `n` MUST be a power -// of two; pad the tail with +INF before calling. Grid-stride: any blockDim -// works. +// Bitonic sort of power-of-two `n` floats in shared memory, ascending. +// Pad the tail with +INF before calling; any blockDim works. __device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { for (int k = 2; k <= n; k <<= 1) { @@ -28,9 +27,8 @@ __device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { } } -// Sorted-array bounds over [lo, hi). lower: first idx with arr[idx] >= v (count -// of elements < v). upper: first idx with arr[idx] > v (count <= v). Advanced -// `lo` exploits per-thread-stride monotonicity; works on global or shared arr. +// Sorted-array bounds over [lo, hi): lower is first >= v, upper first > v. +// Advanced `lo` exploits monotonic strides; global/shared arrays both work. __device__ __forceinline__ int sorted_lower_bound(const float* arr, int lo, int hi, float v) { @@ -56,9 +54,8 @@ __device__ __forceinline__ int sorted_upper_bound(const float* arr, int lo, return lo; } -// Mid-rank of `v` in the merged (ref, grp) arrays. Advances the four -// incremental bounds (pass 0,0,0,0 for a fresh search); reports per-array equal -// counts for tie correction. +// Mid-rank of `v` in merged (ref, grp) arrays with incremental bounds. +// Also reports equal counts per array for tie correction. struct OvoRank { double mid_rank; int n_eq_ref; @@ -90,12 +87,8 @@ __device__ __forceinline__ OvoRank ovo_mid_rank(const float* ref, int n_ref, return r; } -// Amortized tie correction for LARGE/HUGE bands (group is SORTED). Adds only -// the group-only / ref-overlap delta on the precomputed ref base -// ref_tie_sums[col], like MEDIUM. Iterates the group's UNIQUE values only (one -// ref binary search each) so the ref is NOT rescanned per group: O(n_grp_unique -// * log n_ref) vs O(n_ref)/group. Bit-identical: same per-value (t^3 - t) -// terms, reassociated against the shared ref base. +// Amortized tie correction for sorted LARGE/HUGE groups. +// Only unique group values update the precomputed ref tie base. __device__ __forceinline__ void compute_tie_delta_sorted_grp( const float* ref_col, int n_ref, const float* grp_col, int n_grp, double ref_base, double* warp_buf, double* out) { @@ -124,13 +117,9 @@ __device__ __forceinline__ void compute_tie_delta_sorted_grp( *out = finalize_tie_corr(n_ref + n_grp, ref_base + tie); } -// No-tie fast path (tie_correct=False, default). Ranks each group value against -// the sorted REFERENCE only, via the Mann-Whitney U identity: -// R_g = n_grp(n_grp+1)/2 + Σ_{g values}(#ref_below + 0.5·#ref_equal) -// Group-internal ranks collapse to the closed form, so the group needs NO sort -// (each value binary-searches the sorted ref) -- skips the group segmented -// sort, ~half of dense-OVO time. rank_sums are exact half-integers => matches -// the tiered path bit-for-bit. Grid (n_cols, n_groups). grp_dense is UNSORTED. +// No-tie fast path: group-internal ranks collapse to the U closed form. +// Each unsorted group value binary-searches the sorted reference; no group +// sort. __global__ void ovo_rank_dense_vs_ref_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, @@ -163,13 +152,8 @@ __global__ void ovo_rank_dense_vs_ref_kernel( } } -// LARGE/HUGE pre-sorted rank kernel. Grid (n_cols, n_groups); each thread -// carries lower/upper bounds across its stride (sorted-grp_col monotonicity). -// SMEM_SORT=true (LARGE, groups <= OVO_LARGE_MAX): load unsorted group into -// dynamic smem (large_padded floats) + bitonic-sort. =false (HUGE): read a -// CUB-segmented-sorted group from global. Post-sort body (incremental mid-ranks -// + amortized ref-tie delta) is shared. Each group owns its rank_sums/tie_corr -// row, so size-gated co-launch (skip_n_grp_le) never aliases. +// LARGE/HUGE rank kernel; LARGE smem-sorts, HUGE reads CUB-sorted groups. +// Post-sort mid-rank/tie body is shared and each group owns its output row. template __global__ void ovo_rank_sorted_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_in, @@ -230,9 +214,8 @@ __global__ void ovo_rank_sorted_kernel( &tie_corr[grp * n_cols + col]); } -// MEDIUM-band helper: tie contribution of the sorted reference alone (one block -// per column). The rank kernels use this base and add only group-only/overlap -// deltas from the group values. +// MEDIUM tie helper: sorted-reference contribution, one block per column. +// Rank kernels add only group-only/ref-overlap deltas. __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, double* __restrict__ ref_tie_sums, int n_ref, int n_cols) { @@ -257,10 +240,8 @@ __global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, if (threadIdx.x == 0) ref_tie_sums[col] = total; } -// MEDIUM-band fused kernel: no-sort direct rank for groups in (skip_n_grp_le, -// max_n_grp_le]. Ranks = ref binary searches + an in-group scan over unsorted -// shared values. Tie correction starts from ref_tie_sums[col] and adds only -// group-only / ref-overlap deltas. +// MEDIUM fused kernel: ref binary searches plus in-group scan over smem values. +// Tie correction starts from ref_tie_sums[col] and adds only group deltas. __global__ void ovo_rank_medium_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets, @@ -337,6 +318,6 @@ __global__ void ovo_rank_medium_kernel( finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); } -// WARP (≤32) and SMALL (33–64) tiers were removed; MEDIUM is now the smallest -// tier, covering all groups ≤ OVO_MEDIUM_MAX. Removed kernels archived with -// restore steps in .claude/wilcoxon-warp-small-tiers-removed.md. +// WARP/SMALL tiers were removed; MEDIUM now covers all groups <= +// OVO_MEDIUM_MAX. Restore notes live in +// .claude/wilcoxon-warp-small-tiers-removed.md. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 01b4eb7f..01cbee41 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -122,11 +122,9 @@ static void launch_ovr_rank_dense_streaming( sync_streams(streams, "dense OVR streaming rank"); } -// Host-streaming dense OVR: pinned-host multi-stream pipeline feeding the dense -// sort+rank above. Reads each sub-batch into an F-order device block (full -// array never transposed): F-order = contiguous memcpy, C-order = strided 2D -// copy. Keys cast to f32; group sums (+nnz) accumulated in f64 from native -// staging. +// Host-streaming dense OVR: pinned multi-stream batches into F-order device +// slabs. F-order copies contiguous; C-order uses 2D copy; stats accumulate in +// f64. template static void launch_ovr_rank_dense_host_streaming( const T* h_X, bool f_order, const int* group_codes, double* rank_sums, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh index d92238f2..156648c7 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh @@ -10,9 +10,8 @@ __device__ __forceinline__ double warp_reduce_sum(double v) { return v; } -// Block-wide sum of `val` across all threads. `warp_buf` is shared scratch -// holding one double per warp (>= ceil(blockDim.x / 32) <= 32). Result is -// returned on thread 0 (lane 0 of warp 0); other threads get 0.0. +// Block-wide sum of `val` using one shared double per warp. +// Result is returned on thread 0; other threads get 0.0. __device__ __forceinline__ double wilcoxon_block_sum(double val, double* warp_buf) { val = warp_reduce_sum(val); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 02059ecc..8e578c13 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -22,32 +22,26 @@ constexpr int BEGIN_BIT = 0; constexpr int END_BIT = 32; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; -// MEDIUM band cap: groups up to this size use unsorted O(n^2) in-group-count -// rank (no smem sort). Tier dispatch: make_ovo_tier_plan. +// MEDIUM band cap: groups up to this size use unsorted O(n^2) in-group rank. +// Tier dispatch: make_ovo_tier_plan. constexpr int OVO_MEDIUM_MAX = 512; // LARGE band cap (fused smem-sort kernel); beyond it -> HUGE (CUB segmented // sort). constexpr int OVO_LARGE_MAX = 2500; -// Per-stream dense slab budget (f32 items): 128M*4B=512MB slab + 512MB sorted -// copy ≈ 1GB/stream. Sub-batching keeps (n_g * eff_sb_cols) <= this. +// Per-stream dense slab budget: 128M f32 items plus sorted copy ~= 1GB/stream. +// Sub-batching keeps (n_g * eff_sb_cols) within this. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; -// Budget-aware OVO-host pack sizing. Per-stream device scratch that does NOT -// scale with pack nnz: dense + sorted slabs (each <= GROUP_DENSE_BUDGET) plus -// rank/tie/seg/cub headroom. Reserved per target stream when bounding pack nnz -// so the resident packs + sorted ref cache fit device free. +// Budget-aware OVO-host pack sizing for fixed per-stream scratch. +// Reserves dense/sorted slabs plus rank/tie/seg/CUB headroom. constexpr size_t OVO_PACK_FIXED_PER_STREAM = 4 * GROUP_DENSE_BUDGET_ITEMS * sizeof(float); // ~2 GB // Floor for the budget-derived pack-nnz cap: avoid pathological over-splitting // into thousands of tiny packs when device memory is very tight. constexpr size_t OVO_MIN_PACK_NNZ = 64 * 1024 * 1024; // 64M nnz -// Host->device staging-ring slot cap (nnz). Bounds the page-locked footprint: -// a pack's device buffer is filled in row-blocks of <= this many nonzeros, so -// the cold pin stays small instead of seconds when pack nnz is large. 32M nnz -// (128MB vals + 128MB cols/slot) is the joint sweet spot across scales: it -// crushes the whole-pack pin at 2M (~2.7x) yet stays well clear of a sharp -// large-scale slowdown seen with much smaller blocks at multi-billion nnz. +// H2D staging-ring slot cap: keeps page-locked footprint bounded per row-block. +// 32M nnz was the best compromise across small and multi-billion-nnz scales. constexpr size_t STAGE_RING_NNZ_CAP = 32 * 1024 * 1024; // Query CUB segmented-radix-sort scratch size. Float keys, int values/offsets. @@ -103,16 +97,12 @@ static inline void cub_segmented_sortpairs( what); } -// Universal CUDA static per-block shared-memory floor; safe fallback if the -// device query fails. +// Universal CUDA static per-block shared-memory floor. +// Safe fallback if the device query fails. constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; -// CRITICAL: per-block smem limit (cached per device) powering every smem/gmem -// and tier decision (ovr_smem_config, sparse_ovr_smem_config, -// cast_accumulate_smem_config, make_ovo_tier_plan). DO NOT hardcode a smem -// value in place of this call -- gmem-fallback thresholds (e.g. sparse OVR -// ~3056 groups) auto-scale with the GPU. Falls back to 48 KB if the query -// fails. +// CRITICAL: cached per-device smem limit drives every smem/gmem/tier decision. +// Do not hardcode thresholds; sparse OVR fallback auto-scales with the GPU. static inline size_t wilcoxon_max_smem_per_block() { int device = 0; if (cudaGetDevice(&device) != cudaSuccess) { @@ -140,9 +130,8 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -/** Per-row stats codes for a pack of K groups. From pack_grp_offsets (size K+1, - * relative to pack start), write stats_codes[r] = base_slot + group_idx(r) via - * binary search over the K+1 offsets. */ +/** Per-row stats codes for a pack of K groups. + * Writes stats_codes[r] = base_slot + group_idx(r) by offset binary search. */ __global__ void fill_pack_stats_codes_kernel( const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, int K, int base_slot) { @@ -160,9 +149,8 @@ __global__ void fill_pack_stats_codes_kernel( stats_codes[r] = base_slot + lo; } -// Per-group stats over an already-compact CSR (accumulate half of the mapped -// gather kernel, decoupled for host-staged data). slot = stats_codes[r] or -// fixed_slot; slot outside [0,n_groups_stats) is skipped. +// Per-group stats over compact CSR, decoupled for host-staged data. +// Slot comes from stats_codes[r] or fixed_slot; out-of-range slots are skipped. __global__ void csr_compact_accumulate_kernel( const float* __restrict__ d_data_f32, const int* __restrict__ d_indices, const int* __restrict__ d_indptr, const int* __restrict__ d_stats_codes, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 4eae5a45..2a25e26c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -1,10 +1,7 @@ #pragma once -/** - * CSR-direct OVO streaming pipeline. Reference rows are extracted and sorted - * once across all columns; each group sub-batch ranks against that cached slice - * (mirrors the host-CSR path, avoids per-column reference re-extraction+sort). - */ +/** CSR-direct OVO pipeline: cache sorted reference columns once. + * Group sub-batches rank against that cache, matching the host-CSR path. */ template static void ovo_streaming_csr_impl( const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, @@ -45,8 +42,8 @@ static void ovo_streaming_csr_impl( } int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); { - // Ref cache = 2 floats/col/ref-row; size to ~1/3 of the allocator - // budget, leaving room for group buffers. + // Ref cache uses dense+sorted floats per column/ref row. + // Size to ~1/3 allocator budget, leaving room for group buffers. size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; size_t target_bytes = rmm_available_device_bytes(1.0 / 3.0); if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { @@ -73,9 +70,8 @@ static void ovo_streaming_csr_impl( cub_temp_bytes = cub_grp_bytes; } - // Clamp streams to the per-stream scratch budget (mirrors host OVO): the - // group slab scales with cell count, so a fixed stream count would OOM at - // scale. Ref cache is allocated separately, so reserve its footprint first. + // Clamp streams to budget: group slabs scale with cell count. + // Ref cache is allocated separately, so reserve its footprint first. { size_t per_stream = sub_grp_items * sizeof(float) + @@ -232,10 +228,8 @@ static void ovo_streaming_csr_impl( } } -/** - * CSC-direct OVO streaming pipeline. Like the CSR variant, but extracts rows - * via lookup maps to operate on native CSC input without converting the matrix. - */ +/** CSC-direct OVO pipeline: extracts rows via lookup maps. + * Operates on native CSC input without converting the matrix. */ template static void ovo_streaming_csc_impl( const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, @@ -287,9 +281,8 @@ static void ovo_streaming_csc_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - // Clamp streams to the per-stream scratch budget (mirrors host OVO): the - // ref/group slabs scale with cell counts, so a fixed count would OOM at - // scale. + // Clamp streams to per-stream scratch budget. + // Ref/group slabs scale with cell counts, so fixed counts can OOM. { size_t per_stream = 2 * sub_ref_items * sizeof(float) + diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index 4812881a..a8aae1cb 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -96,11 +96,8 @@ static OvoHostCsrPackPlan plan_ovo_host_csr_packs( return plan; } -/** - * Host-streaming CSC OVO pipeline: CSC on host, only each column sub-batch is - * sent to GPU; row maps + group offsets uploaded once; results written back - * per sub-batch. - */ +/** Host-streaming CSC OVO: send only each column sub-batch to GPU. + * Row maps/group offsets upload once; results scatter per sub-batch. */ template static void ovo_streaming_csc_host_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, @@ -361,12 +358,8 @@ static void ovo_streaming_csc_host_impl( sync_streams(streams, "wilcoxon streaming"); } -/** - * Host CSR OVO pipeline (no full-matrix page-lock). - * - * Ref + each group pack are host-gathered into pinned staging, bulk-H2D'd, - * then extract dense -> segmented sort -> rank vs cached sorted ref -> scatter. - * Packs round-robin across N_STREAMS; per-slot events overlap gather + compute. +/** Host CSR OVO pipeline with no full-matrix page-lock. + * Pinned pack staging feeds dense extract, sort, rank vs cached ref, scatter. */ template static void ovo_streaming_csr_host_impl( @@ -433,9 +426,8 @@ static void ovo_streaming_csr_host_impl( ref_stream); } - // No full-matrix page-lock (the 280GB cudaHostRegister was ~7s/call). The - // gather reads pageable CSR and transfers only the compacted slice; just - // the small pinned staging buffers are registered. + // No full-matrix page-lock: large cudaHostRegister was seconds per call. + // Gather reads pageable CSR and pins only small staging buffers. (void)n_full_rows; // Pinned staging for the reference gather (compacted f32 vals + int32 @@ -457,10 +449,9 @@ static void ovo_streaming_csr_host_impl( cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), cudaMemcpyHostToDevice); - // Phase 1: Ref setup (scoped scratch, ref_sorted persists). - // The [n_ref × n_cols] sorted cache is built one COLUMN CHUNK at a time so - // each CUB segmented sort stays within int32 and the extract scratch is - // chunk-bounded; this is what lets n_ref × n_cols > INT_MAX work. + // Phase 1: ref setup with scoped scratch; sorted cache persists. + // Build by column chunk so CUB item counts and extract scratch stay + // bounded. size_t ref_items = (size_t)n_ref * (size_t)n_cols; if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { throw std::runtime_error( @@ -578,12 +569,8 @@ static void ovo_streaming_csr_host_impl( constexpr size_t window_value_bytes = sizeof(WilcoxonSparseWindowDTypes::value_type); - // Clamp streams to the device-memory budget (90% of free). The per-stream - // pack buffers + dense slabs dominate device use, so a fixed stream count - // OOMs at scale / on smaller GPUs. The sorted ref cache + small shared - // arrays are already allocated, so the measured free already excludes them; - // budget is what is left for the per-stream scratch. Fewer streams just - // means less gather/compute overlap, not a re-stream. + // Clamp streams to the post-ref free-memory budget. + // Per-stream pack buffers dominate; fewer streams reduce overlap only. { size_t per_stream = sparse_window_nnz_bytes(max_pack_nnz) + @@ -652,11 +639,9 @@ static void ovo_streaming_csr_host_impl( } } - // Small rolling pinned staging shared across packs: each pack's device - // buffer is filled in row-blocks of <= stage_cap nnz, so the page-locked - // footprint stays small regardless of pack nnz (the whole-pack pin was the - // dominant cost at small/medium scale). Extra slots let the host gather run - // ahead of the in-flight H2Ds. + // Rolling pinned staging fills pack device buffers in <= stage_cap nnz + // blocks. This keeps page-locked footprint small while extra slots overlap + // H2D. size_t stage_cap = std::min(max_pack_nnz, STAGE_RING_NNZ_CAP); int ring_slots = n_streams + 2; HostStagingRing stage(ring_slots, stage_cap); @@ -720,9 +705,8 @@ static void ovo_streaming_csr_host_impl( CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); } - // Host-gather pack rows into the rolling staging in row-blocks (<= - // stage_cap nnz each), H2D each block into the pack's device buffer at - // its nnz offset, then accumulate stats over the full pack. + // Host-gather pack rows into rolling staging blocks, then H2D by + // offset. Stats accumulate once over the full device-resident pack. if (pack.nnz > 0) { IndptrT pack_base = h_grp_indptr_compact[row_start]; int rb0 = 0; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index a7358246..30170651 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -4,10 +4,8 @@ #include "../sparse_extract/sparse_extract.cuh" -/** - * Build CUB segmented-sort ranges for HUGE-band groups. Ranges point into the - * original dense group layout (normal per-group positions). - */ +/** Build CUB segmented-sort ranges for HUGE-band groups. + * Ranges point into the original dense group layout. */ __global__ void build_huge_seg_offsets_kernel( const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, @@ -58,11 +56,8 @@ __global__ void dense_ovo_group_stats_kernel( } } -/** - * Sizing knobs for LARGE-band dispatch: largest group fits in smem -> fused - * bitonic-sort + binary-search kernel per block; else fall back to HUGE band - * (CUB segmented sort + pre-sorted rank kernel). - */ +/** Sizing knobs for LARGE/HUGE dispatch. + * LARGE uses fused smem sort; HUGE uses CUB sort plus pre-sorted rank. */ struct OvoTierPlan { int max_grp_size = 0; bool run_medium = false; // MEDIUM band: any group ≤ OVO_MEDIUM_MAX @@ -73,15 +68,8 @@ struct OvoTierPlan { size_t large_smem = 0; }; -// Single source of truth for OVO tier dispatch (dense + all four sparse OVO -// impls). Scans group sizes once; co-launches by max group size: -// MEDIUM (<=512): ovo_rank_medium_kernel (no sort; O(n^2) in-group count) -// LARGE (<=2500): ovo_rank_sorted_kernel (fused smem bitonic sort) -// HUGE (>2500): CUB segmented sort + ovo_rank_sorted_kernel -// MEDIUM co-launches with the upper tier, which skips groups <= OVO_MEDIUM_MAX -// (skip_n_grp_le). LARGE falls back to HUGE if smem exceeds the per-block -// limit. (WARP/SMALL sub-tiers removed -- -// .claude/wilcoxon-warp-small-tiers-removed.md.) +// Single OVO tier planner shared by dense and all sparse implementations. +// MEDIUM co-launches; LARGE falls back to HUGE if smem exceeds device limits. static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { OvoTierPlan c; for (int g = 0; g < n_groups; g++) { @@ -99,9 +87,8 @@ static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { c.large_tpb = std::min(c.large_padded, MAX_THREADS_PER_BLOCK); // dynamic smem = grp_smem only; warp_buf is static in the kernel. c.large_smem = (size_t)c.large_padded * sizeof(float); - // Device-adapt: if fused-sort buffer exceeds the per-block smem limit, - // fall back to HUGE (no smem cap). Inert at the ~16.6KB threshold; - // guards against threshold/device-limit changes. + // Device-adapt fused-sort smem to the per-block limit. + // If it no longer fits, fall back to HUGE with no smem cap. if (c.large_smem > wilcoxon_max_smem_per_block()) { c.run_large = false; } @@ -157,20 +144,16 @@ struct OvoTierScratch { uint8_t* grp_cub_temp; // HUGE: CUB scratch }; -// SINGLE OVO ranking engine, shared by dense + all four sparse OVO impls -// (host/device CSC/CSR). Given a sorted reference slice and a dense group slice -// for one column sub-batch, runs the size-banded dispatch from `plan` (see -// make_ovo_tier_plan). Callers differ only in how they produce ref_sorted / -// grp_dense. +// Single OVO ranking engine shared by dense and all sparse host/device paths. +// Callers differ only in how they produce ref_sorted and grp_dense. static inline void ovo_dispatch_tiers( const float* ref_sorted, const float* grp_dense, const int* grp_offsets, const OvoTierPlan& plan, const OvoTierScratch& sc, const int* d_sort_group_ids, int n_sort_groups, size_t grp_cub_temp_bytes, int sb_grp_items_actual, int tpb_rank, int n_ref, int n_all_grp, int sb_cols, int n_groups, bool compute_tie_corr, cudaStream_t stream) { - // No-tie fast path (tie_correct=False, the default): rank each group value - // vs the sorted reference only (U-identity), skipping group sort + all - // tiers. grp_dense is unsorted here, which this kernel wants. + // No-tie fast path: rank unsorted group values vs sorted ref (U-identity). + // Skips group sort and all tier kernels. if (!compute_tie_corr) { constexpr int VS_REF_BLOCK = 256; dim3 grid(sb_cols, n_groups); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index e516d35e..fd8e1cdc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -18,12 +18,8 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, } } -// CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). -// smem-vs-gmem for the DENSE OVR rank kernel. Per-block accumulator is -// (n_groups+32) doubles; over the per-block smem limit (~48 KB) it falls back -// to gmem (use_gmem=true), flipping at roughly n_groups > 6112. Not dead: smem -// mode with an oversized request fails to launch. Limit device-queried via -// wilcoxon_max_smem_per_block(), so it auto-scales. +// CRITICAL: dense OVR gmem fallback is load-bearing for large n_groups. +// Shared-memory thresholds are device-queried; oversized smem would not launch. static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); if (need <= wilcoxon_max_smem_per_block()) { @@ -35,15 +31,8 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { return 32 * sizeof(double); } -/** - * CRITICAL — DO NOT REMOVE the gmem branch. Load-bearing path for Perturb-seq / - * pooled-CRISPR DE (n_groups in the thousands). smem-vs-gmem for the sparse OVR - * rank kernel. Per-block accumulator is (2*n_groups+32) doubles (grp_sums + - * grp_nz_count + warp buf); over the per-block smem limit (~48 KB) the kernel - * CANNOT launch in smem mode, so use_gmem=true routes to a caller gmem buffer. - * Flips at roughly n_groups > 3056. Twice mistaken for dead code; it is the - * ONLY path that works at large n_groups. Limit device-queried via - * wilcoxon_max_smem_per_block(), so the threshold auto-scales. +/** CRITICAL: sparse OVR gmem fallback is required for Perturb-seq-scale groups. + * Shared-memory thresholds are device-queried; oversized smem cannot launch. */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); @@ -55,10 +44,8 @@ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { return 32 * sizeof(double); } -/** - * Fill sort values with row indices [0,1,...,n_rows-1] per column. - * Grid: (n_cols,), block: 256 threads. - */ +/** Fill sort values with row indices [0, 1, ..., n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. */ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, int n_cols) { int col = blockIdx.x; @@ -69,13 +56,8 @@ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, } } -/** - * Read one dense column-batch (native `T`) into f32 F-order (the layout the - * segmented sort expects); single sub-batch only, full array never transposed. - * f_order=true : staging already F-order -> identity cast. - * f_order=false: staging C-order; read into the F-order slot while casting. - * Grid-stride over n_rows*sb_cols. - */ +/** Read one dense column batch into f32 F-order for segmented sort. + * F-order is identity cast; C-order reads into F-order while casting. */ template __global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, float* __restrict__ out, int n_rows, @@ -94,13 +76,8 @@ __global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, } } -/** - * Accumulate per-(group, column) sums (+optional nnz) from a dense - * column-batch, reading NATIVE staging in f64 so means match the Aggregate path - * (the f32 cast is only for ranking). One block per column; - * group_sums/group_nnz are this batch's (n_groups x sb_cols) buffers and must - * be pre-zeroed. - */ +/** Accumulate dense batch per-group sums and optional nnz in f64. + * Reads native staging so means match Aggregate; ranking cast is separate. */ template __global__ void dense_group_accumulate_kernel( const T* __restrict__ stg, const int* __restrict__ group_codes, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index e1c4d7ed..8cff570d 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -1,10 +1,7 @@ #pragma once -/** - * Sparse-aware host-streaming CSC OVR pipeline. - * Sorts only stored nonzeros per column: GPU mem O(max_batch_nnz), sort work - * O(nnz) not O(n_rows). - */ +// Host-streaming CSC OVR: sort only stored nonzeros per column. +// GPU memory is O(max_batch_nnz), not O(n_rows * n_cols). template static void ovr_sparse_csc_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, @@ -243,20 +240,9 @@ static void ovr_sparse_csc_host_streaming_impl( sync_streams(streams, "sparse host CSC streaming"); } -// ============================================================================ -// Sparse-aware host-streaming CSR OVR pipeline. -// ============================================================================ - -/** - * Out-of-core OVR for a host CSR too large to stage on the GPU. - * - * PRECONDITION: column indices sorted within each row. A per-row cursor (init - * 0) walks the matrix ONCE: for each ascending column batch [col, col_end) - * every row resumes where the prior batch stopped. Cursor advances - * monotonically, so each nonzero is read + bulk-transferred exactly once (true - * 1x transfer, not per-batch whole-CSR re-streaming). Histogram counted on - * host; full CSR never page-locked (gather reads it on CPU). Single stream. - */ +// Host CSR rowstream OVR for matrices too large to stage on the GPU. +// Sorted rows let cursors advance once, so each nnz is gathered/transferred +// once. template static void ovr_sparse_csr_host_rowstream_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, @@ -272,8 +258,7 @@ static void ovr_sparse_csr_host_rowstream_impl( int tpb = UTIL_BLOCK_SIZE; size_t budget = rmm_available_device_bytes(0.8); - // ---- Phase 0: host column histogram, threaded by row range; each worker - // counts into a private array (no false sharing), merged after. ---- + // Host column histogram; each worker counts privately, then merges. std::vector h_col_counts(n_cols, 0); { int n_workers = host_worker_count(); @@ -288,9 +273,8 @@ static void ovr_sparse_csr_host_rowstream_impl( for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; } - // ---- Column batch size: int32 CUB limit + device buffers fit budget. - // Per-nnz: gather mini-CSR (val+col) + CSC accum (val+f32+row) + sort out - // (key+row) + CUB temp. ---- + // Column batches must satisfy int32 CUB limits and device memory budget. + // Per-nnz scratch covers mini-CSR gather, CSC accum, sort output, and CUB. constexpr size_t BYTES_PER_NNZ = 2 * sizeof(InT) // gather val + csc val + 2 * sizeof(float) // f32 key in + out + 3 * sizeof(int) // gather col + 2 rows @@ -318,9 +302,8 @@ static void ovr_sparse_csr_host_rowstream_impl( size_t smem_cast = cast_accumulate_smem_config( n_groups, compute_nnz, compute_totals, cast_use_gmem); - // ---- Host gather staging (pinned for bulk H2D) + per-row cursor. Full CSR - // NOT page-locked: gather reads it on CPU, only compacted slice crosses - // bus. + // Host gather staging is pinned; full CSR stays pageable on CPU. + // Only the compacted column interval crosses the bus. size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; PinnedRing gather_stage(1, stage_nnz); PinnedRing indptr_stage(1, (size_t)n_rows + 1); @@ -364,11 +347,9 @@ static void ovr_sparse_csr_host_rowstream_impl( ScopedCudaStream row_stream(cudaStreamDefault); cudaStream_t stream = row_stream.get(); - // ---- One linear column-batched pass. Cursor advances monotonically - // (sorted indices + ascending batches): each nonzero read/transferred once, - // no whole-matrix re-streaming. Threaded gather: count each row's run, - // prefix-sum to per-row offsets, copy rows into disjoint staging ranges. - // ---- + // One ascending column pass; sorted-row cursors make transfer one-shot. + // Threaded gather counts row runs, prefix-sums, then copies disjoint + // ranges. std::vector g_count(n_rows); int col = 0; for (int b = 0; b < n_batches; b++) { @@ -463,12 +444,8 @@ static void ovr_sparse_csr_host_rowstream_impl( cuda_check(cudaStreamSynchronize(stream), "rowstream sync"); } -/** - * Host CSR variant of the sparse OVR stream. - * CSR stays in host memory; columns counted once, then mapped pinned arrays - * feed bounded per-column-batch CSR->CSC scatter on the GPU -- avoids both a - * full sparse upload and any whole-matrix CSR->CSC conversion. - */ +// Host CSR sparse OVR stream: keep CSR on host and batch CSR->CSC scatter. +// Avoids full sparse upload and whole-matrix CSR->CSC conversion. template static void ovr_sparse_csr_host_streaming_impl( const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, @@ -508,10 +485,8 @@ static void ovr_sparse_csr_host_streaming_impl( cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); - // Stage indices on device when they fit so histogram + scatter read at HBM - // speed not over the bus. Both need indices, so staged first; data (equal - // size) staged later only if it fits too. Bulk pageable copy is - // driver-staged -- no host registration. + // Stage indices first when they fit so histogram/scatter read at HBM speed. + // Data is staged too only if data plus one stream buffer still fits. IndexT* d_indices = nullptr; bool indices_staged = total_nnz > 0 && idx_bytes <= budget / 2; if (total_nnz > 0) { @@ -530,9 +505,8 @@ static void ovr_sparse_csr_host_streaming_impl( } } - // ---- Phase 0: per-column nnz counts on the GPU ---- - // CSR has no column structure -> CPU count is a serial pass over every nnz. - // Histogram device-accessible indices; only n_cols counts come back. + // Count per-column nnz on GPU; CSR has no native column structure. + // Only n_cols counts are copied back for batch planning. std::vector h_col_counts(n_cols, 0); if (total_nnz > 0) { unsigned int* d_col_counts = pool.alloc(n_cols); @@ -547,10 +521,8 @@ static void ovr_sparse_csr_host_streaming_impl( "OVR host CSR column-count D2H"); } - // Each batch sorted in one CUB segmented call (int32 item count); its - // CSR->CSC transpose lives in per-stream scratch (~BYTES_PER_NNZ/nnz). - // Shrink sub_batch_cols until densest window fits BOTH the int32 limit AND - // a per-stream budget slice (tall matrices neither overflow CUB nor OOM). + // Each batch uses one CUB segmented sort and per-stream CSR->CSC scratch. + // Shrink sub_batch_cols until item counts and memory budget both fit. constexpr size_t BYTES_PER_NNZ = sizeof(InT) + sizeof(float) + 2 * sizeof(int) + 8; // buffers + CUB temp size_t batch_nnz_cap = SAFE_BATCH_NNZ; @@ -598,9 +570,8 @@ static void ovr_sparse_csr_host_streaming_impl( per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); } - // Stage data too when indices resident and data + one stream's transpose - // buffers fit (scatter reads values at HBM speed). Else data stays mapped - // zero-copy (bounded for matrices too large to stage). + // Stage data when indices are resident and one transpose stream still fits. + // Otherwise values stay mapped zero-copy for bounded-memory streaming. size_t resident = indices_staged ? idx_bytes : 0; bool data_staged = total_nnz > 0 && indices_staged && resident + data_bytes + per_stream_bytes <= budget; @@ -762,9 +733,7 @@ static void ovr_sparse_csr_host_streaming_impl( sync_streams(streams, "sparse host CSR streaming"); } -// ============================================================================ // Sign-safe sparse OVR path: sparse window -> dense f32 tile -> dense rank. -// ============================================================================ constexpr int SPARSE_DENSE_OVR_CHUNK_COLS = 512; @@ -1239,9 +1208,7 @@ static void ovr_dense_csr_host_streaming_impl( } } -// ============================================================================ -// Sparse-aware CSC OVR streaming (sort only stored nonzeros) -// ============================================================================ +// Sparse-aware CSC OVR streaming: sort only stored nonzeros. template static void ovr_sparse_csc_streaming_impl( @@ -1336,9 +1303,8 @@ static void ovr_sparse_csc_streaming_impl( CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); } - // Sort stored values (keys=data, vals=row_indices). Row indices fit - // int32 (n_rows < 2^31); downcast int64 here so sort + rank stay int32 - // (half the val buffer) -- the device boundary. + // Sort stored values; row indices become int32 sort values here. + // This keeps sort/rank int32 while preserving int64 sparse buffers. if (batch_nnz > 0) { const int* idx_src; if constexpr (sizeof(IndexT) > sizeof(int)) { @@ -1381,17 +1347,8 @@ static void ovr_sparse_csc_streaming_impl( sync_streams(streams, "sparse ovr streaming"); } -// ============================================================================ -// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) -// ============================================================================ - -/** - * Sparse-aware OVR streaming pipeline for GPU CSR data. - * P0: histogram nnz per column -> per-batch nnz + max_batch_nnz for sizing. - * P1: alloc per-stream buffers sized to max_batch_nnz. - * P2: per sub-batch scatter CSR->CSC (partial atomic transpose) -> CUB sort - * only nonzeros -> sparse rank. Sort work drops ~1/sparsity vs dense. - */ +// Sparse-aware CSR OVR streaming with partial CSR->CSC transpose per batch. +// Histogram plans batches; each batch transposes, sorts nnz only, then ranks. template static void ovr_sparse_csr_streaming_impl( const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh index 9b5e7377..d02103cc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh @@ -2,12 +2,8 @@ #include -// Walk this thread's chunk [my_start, my_end) of a sorted column, accumulating -// tie-averaged ranks into grp_sums (atomic, strided by acc_stride). Ties that -// straddle a chunk boundary are expanded to their global extent within -// [seg_floor, seg_ceil) by binary search. `rank_offset` shifts every rank (the -// sparse path uses it to account for implicit leading zeros). Returns this -// thread's tie-correction sum (sum of t^3 - t over tie blocks it owns). +// Walk one sorted-column chunk and accumulate tie-averaged ranks atomically. +// Boundary ties are expanded by search; sparse paths pass a rank_offset. template __device__ __forceinline__ double ovr_walk_tie_runs( const float* sv, const IndexT* si, const int* group_codes, double* grp_sums, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 7ecd99ac..a5078997 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -5,19 +5,9 @@ #include "wilcoxon_block_reduce.cuh" #include "wilcoxon_ovr_tie_walk.cuh" -// Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. Ranks -// ONLY stored positives; all zeros (stored + implicit n_rows-nnz) form one -// leading tie block ranked analytically at (total_zero+1)/2, each group's zero -// contribution in closed form -> O(nnz log nnz)/col. Sentinel-group (n_groups) -// rows feed the rest/tie denominator but get no rank-sum accumulation. -// -// CRITICAL: validity relies on upstream rejection of negative sparse values -// (guarantees zeros form the first tie block). use_gmem selects shared- vs -// global-memory accumulators (sparse_ovr_smem_config), REQUIRED for large -// n_groups (Perturb-seq) -- do not remove. -// -// Grid (sb_cols,), Block (tpb,). Shared (doubles): grp_sums[n_groups] + -// grp_nz_count[n_groups] + warp_buf[32]. +// Sparse OVR rank for nonnegative stored values; zeros rank analytically. +// CRITICAL: negative rejection and gmem fallback are required at large +// n_groups. template __global__ void rank_sums_sparse_ovr_kernel( const float* __restrict__ sorted_vals, @@ -141,11 +131,8 @@ __global__ void rank_sums_sparse_ovr_kernel( } } -// Shared sparse-OVR rank launch (all four sparse OVR impls). Optionally zeroes -// the gmem accumulators, then launches the analytic-zero rank kernel. use_gmem -// is the CRITICAL large-n_groups/perturbation fallback (see -// sparse_ovr_smem_config) — DO NOT drop the gmem branch. ValT is the -// sorted-row-index type (int everywhere today). +// Shared sparse-OVR rank launch for all sparse OVR implementations. +// CRITICAL: keep the gmem fallback for large-n_groups perturbation DE. template static inline void launch_ovr_sparse_rank( const float* sorted_vals, const ValT* sorted_row_idx, @@ -167,11 +154,8 @@ static inline void launch_ovr_sparse_rank( CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); } -// CRITICAL — DO NOT REMOVE the gmem branch (large n_groups / perturbation DE). -// smem-vs-gmem for the sparse-OVR stats cast+accumulate kernel. Needs -// n_arrays*n_groups doubles in smem; over the per-block limit, use_gmem=true -// selects ovr_cast_and_accumulate_sparse_global_kernel (accumulates in gmem). -// Load-bearing fallback, not dead. +// CRITICAL: sparse stats gmem fallback is load-bearing for large n_groups. +// It selects the global accumulator when smem would exceed the per-block limit. static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, bool compute_totals, bool& use_gmem) { int n_arrays = 1 + (compute_nnz ? 1 : 0); @@ -185,9 +169,8 @@ static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, return compute_totals ? WARP_REDUCE_BUF * sizeof(double) : 0; } -// Shared cast+accumulate loop for the two sparse-OVR stats kernels. Casts each -// stored value to f32 and atomically accumulates per-group sums (+nnz), strided -// by acc_stride (1 for per-block smem, sb_cols for the gmem row-major layout). +// Shared cast+accumulate loop for sparse-OVR stats kernels. +// Casts to f32 for sort and atomically accumulates f64 sums/nnz. template __device__ __forceinline__ void accumulate_group_stats( const InT* data_in, float* data_f32_out, const IndexT* indices, @@ -212,13 +195,8 @@ __device__ __forceinline__ void accumulate_group_stats( } } -/** - * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. Sub-batch - * CSC column c lives at [col_seg_offsets[c], col_seg_offsets[c+1]); writes an - * f32 copy for the CUB sort and accumulates per-group sum/nnz in f64 (implicit - * zeros contribute nothing). Block-per-column (grid (sb_cols,), block (tpb,)), - * smem (1+compute_nnz)*n_groups doubles. - */ +/** Pre-sort cast-and-accumulate kernel for sparse OVR streaming. + * Writes f32 sort keys and accumulates explicit-value sums/nnz in f64. */ template __global__ void ovr_cast_and_accumulate_sparse_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, @@ -273,10 +251,8 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } } -// CRITICAL — DO NOT REMOVE. Gmem variant of the stats accumulator, selected by -// cast_accumulate_smem_config when n_groups is too large for the smem kernel -// (its n_arrays*n_groups double buffer exceeds the per-block limit). Required -// for Perturb-seq-scale n_groups. +// CRITICAL: gmem stats accumulator for n_groups too large for smem. +// Required for Perturb-seq-scale group counts. template __global__ void ovr_cast_and_accumulate_sparse_global_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index e68e02c7..6d39ab60 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -155,10 +155,7 @@ def _init_stats_arrays(self, n_genes: int) -> None: def _basic_stats(self) -> None: """Compute means, vars, and pts for each group. - - If data is already on GPU, uses Aggregate for fast single-pass computation. - Otherwise, sets flag for chunk-based computation during the wilcoxon loop. - """ + Host data defers stats to the Wilcoxon chunk/streaming path.""" n_genes = self.X.shape[1] try: @@ -189,9 +186,8 @@ def _basic_stats(self) -> None: cat_to_idx = {str(name): i for i, name in enumerate(cat_names)} order = [cat_to_idx[str(name)] for name in self.groups_order] - # Aggregate returns stats per ALL categories. Slice to selected groups - # for per-group means/vars; keep the all-category arrays for "rest" - # stats so the totals stay correct when ``groups`` is a strict subset. + # Aggregate returns all categories; slice selected groups for outputs. + # Keep all-category totals so ``groups`` subsets get correct rest stats. sums_all = result["sum"] sq_sums_all = result["sq_sum"] nnz_all = result["count_nonzero"] if self.comp_pts else None @@ -209,9 +205,8 @@ def _basic_stats(self) -> None: else: pts = None - # Compute rest statistics if reference='rest' — "rest" means every - # cell in ``groupby`` not in this group, including cells in - # categories that weren't selected via ``groups=``. + # For reference='rest', rest includes every category not in this group. + # That includes categories omitted by a strict ``groups=`` selection. if self.ireference is None: n_total = agg.n_cells.sum() n_rest = n_total - n @@ -374,17 +369,12 @@ def compute_statistics( **kwds, ) -> None: """Compute statistics for all groups.""" - # The optimized sparse Wilcoxon paths inject implicit zeros analytically - # as a tie at the column minimum (valid only for nonnegative data). - # t-test/logreg are mean/variance/model-based and sign-agnostic. For the - # Wilcoxon methods we canonicalize and, when sparse data holds - # negatives, route to sign-safe dense ranking inside the sparse - # streamers rather than erroring. + # Sparse Wilcoxon handles implicit zeros analytically only for nonnegative data. + # Signed sparse Wilcoxon routes to sign-safe dense ranking inside streamers. self._sparse_negative_fallback = False if method in {"wilcoxon", "wilcoxon_binned"}: - # Canonicalize before the negative check: summing duplicates can - # change stored values (e.g. +a and -a -> 0), and the fast paths - # rank each stored nnz once, so they must see scanpy's summed view. + # Canonicalize before the negative check because summing duplicates can change signs. + # Fast paths rank stored nnz once, so they must see scanpy's summed view. self.X = _canonicalize_sparse(self.X) self._sparse_negative_fallback = _sparse_has_negative(self.X) if method in { diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py index 1232fe28..d2decc70 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py @@ -24,12 +24,8 @@ def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]: X = rg.X[selected, :] codes = rg.group_codes[selected] - # Encode the multinomial class labels in canonical (original category) order - # rather than in `groups_order` order. groups_order echoes the user's - # `groups=` argument (see _select_groups), but cuML's softmax solver is not - # invariant to a class-index permutation, so without this the fitted scores - # would depend on the order groups are listed in. canon_label[i] is the - # class index used for groups_order[i]; coef_ rows are mapped back below. + # Encode multinomial classes in original category order for cuML softmax. + # groups_order follows the user; coef_ rows are mapped back below. cat_order = {str(c): i for i, c in enumerate(rg.labels.cat.categories)} canon_key = np.array([cat_order[str(g)] for g in rg.groups_order]) canon_label = np.empty(n_groups, dtype=np.int64) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index ecd93b10..19bdba49 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -16,15 +16,8 @@ def _sparse_has_negative(X) -> bool: - """Whether an in-memory sparse ``X`` stores an explicit negative value. - - The fast sparse Wilcoxon paths add implicit (structural) zeros as a tie at - the column minimum, which is correct only for nonnegative stored values. A - negative breaks that, so in-memory Wilcoxon routes signed sparse data to - the sign-safe sparse-dense ranker. Dask arrays are not inspected here - (they are neither ``scipy`` nor ``cupy`` sparse); ``wilcoxon_binned`` guards - Dask sparse separately. Dense and t-test/logreg never need this. - """ + """Return whether an in-memory sparse matrix stores a negative value. + Signed sparse Wilcoxon needs the sign-safe sparse-dense ranker.""" if sp.issparse(X) or cpsp.issparse(X): if np.dtype(X.data.dtype).kind == "c": return False @@ -33,13 +26,8 @@ def _sparse_has_negative(X) -> bool: def _canonicalize_sparse(X): - """Sum duplicate entries and sort indices of sparse ``X`` in place. - - The fast Wilcoxon paths rank each stored nonzero once, so non-canonical - input with duplicate ``(row, col)`` entries would diverge from scanpy, - which sums duplicates when it densifies. Canonicalizing keeps them in - agreement. A no-op for already-canonical or dense input. - """ + """Sum duplicates and sort sparse indices in place when needed. + Fast Wilcoxon ranks stored nnz once, so it expects scanpy's summed view.""" if ( (sp.issparse(X) or cpsp.issparse(X)) and getattr(X, "format", None) in {"csr", "csc"} @@ -56,26 +44,8 @@ def _select_groups( reference: str = "rest", skip_empty_groups: bool = False, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: - """Build integer group codes from a categorical Series. - - Parameters - ---------- - labels - Categorical Series (from ``adata.obs[groupby]``). - selected - Group names to keep, or ``None`` for all groups. - Must already include the reference group if applicable. - - Returns - ------- - groups_order - Selected group names as a numpy array. - group_codes - Per-cell int32 codes: ``0..n_groups-1`` for selected cells, - ``n_groups`` (sentinel) for unselected cells. - group_sizes - Number of cells per selected group (int64). - """ + """Build selected group names, per-cell int32 codes, and group sizes. + Unselected cells receive the sentinel code ``n_groups``.""" all_categories = labels.cat.categories if selected is None: @@ -145,13 +115,8 @@ def _choose_chunk_size(requested: int | None) -> int: def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: - """ - Densify a CSC column window [start, stop) into an F-order float64 block via - the fused ``csc_tile_to_dense`` kernel (column-major, coalesced, no atomics). - - Slices the window by indptr pointers so only that window's nonzeros are - touched (and, for host CSC, transferred). Works for scipy and CuPy CSC. - """ + """Densify a CSC column window into an F-order float64 GPU block. + Slices by indptr so only window nonzeros are touched/transferred.""" from rapids_singlecell._cuda import _rank_stats_cuda as _rs s_ptr = int(X_csc.indptr[start]) @@ -174,12 +139,8 @@ def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray def _csr_tile_to_dense_block(X, start: int, stop: int) -> cp.ndarray: - """Densify a CSR column window [start, stop) straight into an F-order - float64 block via a single fused CSR->dense kernel, skipping the CSR->CSC - tile rebuild that ``X[:, start:stop].tocsc()`` (host) / ``X[:, start:stop]`` - (device) would do. For device CSR the index arrays are already on the GPU, - so there is no transfer. - """ + """Densify a CSR column window into an F-order float64 GPU block. + Device CSR avoids rebuilding a CSR/CSC slice before densifying.""" from rapids_singlecell._cuda import _rank_stats_cuda as _rs n_rows = X.shape[0] @@ -201,13 +162,8 @@ def _csr_tile_to_dense_block(X, start: int, stop: int) -> cp.ndarray: def _get_column_block(X, start: int, stop: int) -> cp.ndarray: """Extract a column block as a dense F-order float64 CuPy array.""" match X: - # Device CSR: the fused csr_tile_to_dense kernel densifies the window in - # one pass with no transfer (index arrays are already on the GPU) -- the - # big win. Host CSR is intentionally NOT routed here: doing so would - # re-transfer the whole CSR every chunk (only ~1.15x and worse with more - # chunks); host data should be moved to the device once upstream - # (`X_to_GPU`) so it lands in this fast device branch, otherwise it falls - # through to the `.tocsc()` path below. + # Device CSR can densify in one pass without transfer. + # Host CSR intentionally falls through to avoid per-chunk full transfers. case cpsp.csr_matrix(): return _csr_tile_to_dense_block(X, start, stop) case sp.csc_matrix() | sp.csc_array(): diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 23f15ca0..a06c6d40 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -392,20 +392,8 @@ def _validate_wilcoxon_sparse_dtype(X) -> None: def _device_sparse_arrays(X): - """Prepare device-sparse arrays for the Wilcoxon kernels. - - Wilcoxon ranking sorts float32 keys on every sparse device path, including - the sign-safe sparse-dense OVR path. Casting ``X.data`` to float32 here - therefore does not diverge from any float64 ranking path, because there is - none. This only loses precision when preprocessing ran in float64; - float32-preprocessed values (even if later stored as float64) are - float32-exact, so ranking matches scanpy bit-for-bit (~1e-13). For a fully - float64 pipeline the rank-derived scores/p-values match scanpy-on-float64 - to ~1e-4 on log-normalized data (below any significance threshold, no DE - calls change), while means and log fold changes are still computed in - float64. See the ``rank_genes_groups`` note on ranking precision. float64 - input is accepted to spare the caller a pre-cast. - """ + """Prepare device-sparse arrays for float32-key Wilcoxon kernels. + float64 data is accepted and cast for ranking; stats stay float64.""" data_dtype = np.dtype(X.data.dtype) if data_dtype == np.float32: data = X.data @@ -444,9 +432,8 @@ def wilcoxon( return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - # Host dense OVR and OVO stream column windows from host. Already-device - # dense OVO still uses the device-resident tiered planner. - # Aggregate if on GPU, else defer to chunks. + # Host dense streams column windows; device dense stays device-resident. + # Aggregate stats on GPU, otherwise compute them inside streaming paths. X = rg.X _validate_wilcoxon_sparse_dtype(X) rg._basic_stats() diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index 97308d4f..4e1236ad 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -25,11 +25,7 @@ def _fill_sparse_zero_bin(hist: cp.ndarray, group_counts: cp.ndarray) -> None: - """Fill bin 0 with zero counts for sparse histograms (in-place). - - Sparse kernels only populate bins 1..n_bins (nonzero values). - Bin 0 = group_size - sum(bins 1..n_bins) for each gene/group. - """ + """Fill sparse histogram bin 0 from group size minus nonzero-bin counts.""" nonzero_per_group = hist.sum(axis=2) # (n_genes, n_groups) hist[:, :, 0] = group_counts[None, :].astype(cp.uint32) - nonzero_per_group @@ -73,39 +69,7 @@ def wilcoxon_binned( chunk_size: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, ) -> list[tuple[int, NDArray, NDArray]]: - """Histogram-based approximate Wilcoxon rank-sum test. - - Approximates ranks by discretizing expression values into ``n_bins`` - fixed-width bins, then computing rank sums from cumulative histogram - counts. This avoids the O(n log n) per-gene sort required by exact - Wilcoxon, making it feasible for datasets with millions of cells and - compatible with Dask arrays. - - Supports both one-vs-rest (``reference='rest'``) and one-vs-one - (``reference=''``) comparisons. - - Parameters - ---------- - rg - The _RankGenes instance. - tie_correct - Adjust the variance for ties. In the binned approach each bin - acts as a tie group, so the correction uses the bin counts - directly. - n_bins - Number of histogram bins. Higher = better approximation. - Default is 1000 for in-memory arrays and 200 for Dask arrays. - chunk_size - Genes processed per GPU batch. Controls peak GPU memory. - bin_range - How to determine the histogram bin range. - ``None`` (default) uses ``'auto'`` for in-memory arrays and - ``'log1p'`` for Dask arrays (to avoid a costly data scan). - ``'log1p'`` uses a fixed [0, 15] range suitable for - log1p-normalized data. - ``'auto'`` computes the actual (min, max) of the data, spanning - negatives when present. Use it for data outside the fixed log1p range. - """ + """Histogram-based approximate Wilcoxon rank-sum test.""" if not rg.is_log1p: warnings.warn( "wilcoxon_binned expects log-normalized data " @@ -125,11 +89,8 @@ def wilcoxon_binned( n_cells, n_genes = X.shape group_sizes = rg.group_sizes - # Dask sparse cannot bin negatives correctly: the sparse histogram puts - # implicit zeros in the lowest bin and _data_range floors the range at 0 - # for Dask sparse, so negatives would be silently mis-ranked. Refuse rather - # than return wrong numbers (in-memory sparse negatives use the dense - # fallback; see _sparse_has_negative). + # Dask sparse cannot bin negatives correctly because implicit zeros use bin 0. + # Refuse instead of silently mis-ranking; in-memory sparse uses dense fallback. if isinstance(X, DaskArray) and cpsp.issparse(X._meta): def _block_data_min(block): @@ -154,18 +115,13 @@ def _block_data_min(block): "zeros. Densify the data or use a nonnegative representation." ) - # group_codes: 0..n_groups-1 for selected cells, n_groups (sentinel) - # for unselected. For vs-rest, unselected cells are binned into a - # dummy group so they contribute to total counts for correct midranks. - # For vs-reference, the kernel bounds guard (grp >= n_groups) skips them. + # group_codes use n_groups as sentinel for unselected cells. + # vs-rest bins sentinels for totals; vs-reference kernels skip them. group_codes_np = rg.group_codes has_unselected = bool(np.any(group_codes_np == n_groups)) - # For one-vs-one with a group subset, only the selected groups' cells - # matter for pairwise rankings. Filter X down so kernels don't iterate - # over irrelevant cells. For Dask we can't cheaply subset rows, but - # the kernel bounds guard (grp >= n_groups → skip) avoids wasted - # atomicAdds, so we just clear the flag without allocating a dummy group. + # One-vs-one only ranks selected groups; filter in-memory rows. + # Dask keeps rows but kernels skip sentinels, avoiding dummy-group atomics. if ireference is not None and has_unselected: if isinstance(X, DaskArray): has_unselected = False @@ -212,10 +168,8 @@ def _block_data_min(block): if bin_range is None: bin_range = "log1p" if isinstance(X, DaskArray) else "auto" - # The fixed log1p [0, 15] range assumes nonnegative data. For signed sparse - # input the dense fallback would clamp negatives into the lowest bin and - # silently produce wrong rank sums, so switch to the data-driven 'auto' - # range (which spans the true [min, max], including negatives). + # The fixed log1p range assumes nonnegative data. + # Signed sparse fallback needs data-driven auto range to avoid clamping. if rg._sparse_negative_fallback and bin_range == "log1p": warnings.warn( "bin_range='log1p' is invalid for sparse input with negative values " @@ -309,10 +263,8 @@ def process_gene_batch( is_sparse = False if force_dense and cpsp.issparse(X): - # Negative-values fallback: the sparse histogram assigns implicit zeros - # to bin 0, which is correct only for nonnegative data. Densify the - # column window (chunked, no full materialization) and use the dense - # histogram, whose bins span the full [min, max] range. + # Negative sparse fallback: bin 0 is only correct for nonnegative data. + # Densify the column window so dense bins span the full [min, max]. hist = _launch_dense( _get_column_block(X, start, stop), group_codes, @@ -448,12 +400,7 @@ def _compute_stats_vs_ref( tie_correct: bool = False, use_continuity: bool = False, ) -> tuple[cp.ndarray, cp.ndarray]: - """Compute Wilcoxon z-scores for each group vs a specific reference. - - For each group *g*, midranks are derived from the pairwise histogram - ``hist_g + hist_ref`` so that only cells in the compared pair - contribute to the ranking. - """ + """Compute Wilcoxon z-scores for each group vs a specific reference.""" # hist shape: (n_genes, n_groups, n_bins_total) ref_hist = hist[:, ireference : ireference + 1, :] # (n_genes, 1, n_bins_total) @@ -601,13 +548,7 @@ def _process_dask( inv_bin_width: float, n_bins_total: int, ) -> cp.ndarray: - """Build histogram from a Dask array. - - Receives the full (unsliced) Dask array and column range - ``[start, stop)``. Column selection happens inside each block - handler on the materialised CuPy chunk, keeping the Dask graph - simple (no column-slice node per gene batch). - """ + """Build a column-range histogram from an unsliced Dask array.""" import dask.array as da if cpsp.isspmatrix_csr(X._meta): diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index b076c0be..03500899 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -360,11 +360,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): def test_rank_genes_groups_ttest_direct_scipy(): - """Test t-test scores directly against scipy.stats.ttest_ind on two matrices. - - Creates a simple two-group dataset and compares rapids_singlecell t-test - directly against scipy.stats.ttest_ind without intermediate statistics. - """ + """Compare rapids_singlecell t-test scores directly to scipy.stats.ttest_ind.""" np.random.seed(42) n_group1, n_group2, n_genes = 50, 60, 20 @@ -410,12 +406,7 @@ def test_rank_genes_groups_ttest_direct_scipy(): def test_rank_genes_groups_ttest_matches_scipy(): - """Test that t-test scores match scipy computation directly. - - This test verifies that our variance clipping fix produces correct results - by comparing against scipy.stats.ttest_ind_from_stats with properly computed - (non-negative) variances. Uses real pbmc68k_reduced dataset at float64 precision. - """ + """Compare t-test scores to scipy stats with nonnegative variances.""" adata = pbmc68k_reduced() # Convert to float64 for maximum precision in comparison adata.X = adata.X.astype(np.float64) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 6aba2740..9a166611 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -33,14 +33,8 @@ def _make_nonnegative(adata): return adata -# The optimized sparse Wilcoxon paths inject implicit zeros analytically as a tie -# at the column minimum, which is valid only for nonnegative data. With negatives -# present they must NOT be used; instead the ranking falls back to the dense -# full-sort path (correct for any sign), so the result matches running the same -# method on the dense matrix. (t-test/t-test_overestim_var/logreg never need this -# and accept signed sparse data directly -- e.g. mixscape's LDA t-test.) -# (method, reference) combos: vs-rest for wilcoxon + binned, plus the OVO -# (with-reference) wilcoxon path. binned has no with-reference mode. +# Sparse Wilcoxon negative values must fall back to dense full-sort ranking. +# Covers Wilcoxon OVR/OVO and binned OVR; other methods accept signed sparse. @pytest.mark.parametrize( ("method", "reference"), [("wilcoxon", "rest"), ("wilcoxon_binned", "rest"), ("wilcoxon", "b")], @@ -84,9 +78,8 @@ def test_rank_genes_groups_sparse_negative_values_fallback(method, reference, fm @pytest.mark.parametrize("layout", ["csr", "csc"]) @pytest.mark.parametrize("reference", ["rest", "1"]) def test_device_sparse_int64_indptr_matches_scanpy(layout, reference): - # Real int64 indptr only occurs at nnz > 2^31 (unallocatable in CI). cupy - # >= 14.1 preserves explicitly promoted int64 indices/indptr, so a small - # matrix promoted to int64 drives the int64 device overloads. + # Real int64 indptr needs nnz > 2^31, so CI promotes a small matrix. + # cupy >= 14.1 preserves the promoted int64 buffers for overload coverage. rng = np.random.default_rng(0) dense = np.abs(rng.standard_normal((150, 8))).astype(np.float32) dense[dense < 0.5] = 0.0 @@ -726,10 +719,8 @@ def _make_sized_groups_adata(group_sizes, n_genes, seed=0): return adata -# OVO tiers (wilcoxon_fast_common.cuh): MEDIUM<=512, LARGE(fused smem sort)<=2500, -# HUGE(CUB segmented sort)>2500. Group sizes in the standard blobs datasets are -# <=~70 (all MEDIUM), so LARGE/HUGE are otherwise never exercised. These force a -# single large test group. +# OVO tier coverage: standard blobs hit only MEDIUM. +# These cases force LARGE fused-smem sort and HUGE CUB segmented sort. @pytest.mark.parametrize( "fmt", ["numpy_dense", "cupy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"], @@ -766,14 +757,8 @@ def test_wilcoxon_ovo_large_group_tiers_match_scanpy(fmt, tie_correct, big): ) -# n_groups > ~3056 makes the per-block smem for the sparse-OVR accumulator -# ((2*n_groups+32) doubles) exceed the 48KB static limit, so sparse_ovr_smem_config -# (and the dense ovr_smem_config) fall back to the global-memory accumulator. -# This is the perturbation regime (thousands of guides vs rest). scanpy's -# 3000+-group DataFrame build is O(n_groups^2) and too slow for an in-suite -# parity check; gmem-vs-scanpy parity is verified out-of-band (<=2e-15). Here we -# guard that every storage format (incl. the dense reference kernel) agrees at -# gmem scale, with and without tie correction. +# Many groups force global-memory accumulators, matching perturbation-scale DE. +# scanpy is too slow here, so this guards cross-format agreement at gmem scale. @pytest.mark.parametrize("tie_correct", [False, True]) def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=3) @@ -806,15 +791,8 @@ def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): ) -# Regression guard for the host-dense OVR streaming launcher: in gmem rank mode -# the rank kernel atomicAdds onto its per-stream rank-sum buffer without -# self-zeroing, and those buffers are reused round-robin -- so the launcher must -# zero them per sub-batch. This only bites past the DENSE gmem flip -# (n_groups > ~6112) AND with enough genes (> N_STREAMS*sub_batch_cols = 256) -# that the per-stream buffers actually wrap; the formats-agree gmem test above -# uses n_genes=6 (one batch, fresh buffer) and cannot see it. Compares the host -# (numpy) dense path against the device (cupy) dense + sparse paths, which zero -# the gmem buffer correctly. +# Host-dense OVR gmem buffers are reused round-robin and must be zeroed per batch. +# This forces enough groups and genes to wrap per-stream rank-sum buffers. @pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn def test_wilcoxon_ovr_dense_gmem_host_streaming_buffer_reuse(): adata = _make_sized_groups_adata([2] * 6200, n_genes=400, seed=7) @@ -846,9 +824,8 @@ def test_wilcoxon_ovr_dense_gmem_host_streaming_buffer_reuse(): ) -# Host-dense OVR has only float32/float64 nanobind overloads; integer/bool/uint/ -# float16 numpy must be cast to float32 (mirrors the sparse path) rather than -# raising a TypeError. +# Host-dense OVR has only float32/float64 nanobind overloads. +# Other numpy numeric dtypes must cast to float32 rather than raise. @pytest.mark.parametrize( "data_dtype", [np.int32, np.int64, np.uint16, np.float16, bool] ) @@ -887,12 +864,8 @@ def run(arr): ) -# F-contiguous host-dense numpy hits the F-order nanobind overload of the host -# streaming launcher: float32 -> the reinterpret-cast fast path (no cast kernel), -# float64 -> dense_block_to_f32_kernel's identity branch. Every numpy_dense -# fixture elsewhere is C-order, so this is the only coverage of that overload. -# AnnData preserves F-order, so an F-contiguous X reaches the path; result must -# match the C-order run on identical data. +# F-contiguous host-dense numpy hits the F-order host-streaming overload. +# It must match the C-order run on identical data. @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_wilcoxon_ovr_fortran_order_host_dense_matches_c_order(dtype): rng = np.random.default_rng(11) @@ -927,16 +900,8 @@ def run(arr): ) -# Regression guard for a shared-memory OOB write in the host sparse OVR -# cast-and-accumulate kernel: it placed the per-group nnz accumulator at a fixed -# 2*n_groups smem offset, but cast_accumulate_smem_config packs only the enabled -# arrays -- and the host OVR path runs with sq-sums OFF, nnz ON (pts=True). The -# overrun was benign at tiny n_groups (it landed in rounded smem slack, and the -# write/read used the same wrong offset so values stayed self-consistent) but -# caused an illegal memory access once n_groups grew past ~25. n_groups=50 + -# pts=True is the faulting regime, with the smem (non-gmem) accumulator still -# selected. Covers both host sparse formats (the ones that crashed) plus the -# dense/device formats for full parity. +# Guards host sparse OVR smem packing for pts=True, where nnz offset once overran. +# n_groups=50 stays on smem but reaches the formerly faulting regime. @pytest.mark.parametrize( "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] ) @@ -976,11 +941,8 @@ def test_wilcoxon_ovr_pts_many_groups_match_scanpy(fmt): ) -# Companion to test_wilcoxon_ovr_many_groups_gmem_formats_agree, with pts=True: -# at gmem scale (n_groups > ~3056) the global cast-accumulate and the -# analytic-zero rank kernel both drive the per-group nnz path. scanpy's -# 3000+-group build is too slow for an in-suite parity check, so we assert every -# storage format agrees, including the pts fraction-expressing matrix. +# Companion gmem-scale check with pts=True. +# It exercises global cast-accumulate and analytic-zero nnz paths. def test_wilcoxon_ovr_many_groups_gmem_pts_formats_agree(): adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=5) ref = None @@ -1220,9 +1182,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference): ) -# ============================================================================ -# Ground-truth validation against scipy.stats.mannwhitneyu -# ============================================================================ +# Ground-truth validation against scipy.stats.mannwhitneyu. def _make_perturbation_adata( @@ -1539,10 +1499,7 @@ def test_wilcoxon_group_subset_column_order_matches_scanpy(reference): def test_wilcoxon_host_sparse_negative_chunked_stats_match_scanpy(): - """Host scipy-sparse with negatives takes the dense fallback, whose group - means/vars/pts run through the group_chunk_stats kernel (multi-chunk). Those - means (-> logfoldchanges) and pts must match scanpy. - """ + """Host sparse negatives fallback must match scanpy stats across chunks.""" rng = np.random.default_rng(0) n_obs, n_vars = 200, 24 X = (rng.random((n_obs, n_vars)) * 5.0).astype(np.float64) @@ -1591,11 +1548,7 @@ def test_wilcoxon_host_sparse_negative_chunked_stats_match_scanpy(): def test_wilcoxon_fdr_ties_nan_match_scanpy(): - """BH FDR must match scanpy on heavily-tied / constant / all-zero genes, - locking in that the GPU argsort tie-break is inert for adjusted p-values. - Integer data is float32-exact, so ranking is bit-identical to scanpy and the - comparison isolates the FDR step. - """ + """BH FDR must match scanpy on tied, constant, and all-zero genes.""" rng = np.random.default_rng(1) n_obs, n_vars = 240, 30 X = rng.integers(0, 3, size=(n_obs, n_vars)).astype(np.float64) # heavy ties @@ -1625,12 +1578,7 @@ def test_wilcoxon_fdr_ties_nan_match_scanpy(): def _promote_host_index_dtype(X): - """Copy a host scipy CSR/CSC matrix with promoted index-array dtypes. - - scipy couples indptr/indices to one dtype via get_index_dtype. Real int64 - index buffers only occur at nnz > 2^31 in practice, so tests promote a small - matrix explicitly to drive the int64 templates. - """ + """Copy a host scipy CSR/CSC matrix with promoted index-array dtypes.""" X = X.copy() X.indptr = X.indptr.astype(np.int64) X.indices = X.indices.astype(np.int64) @@ -1648,13 +1596,7 @@ def _promote_host_index_dtype(X): ], ) def test_host_sparse_int64_templates_match_int32(reference, layout, data_dtype): - """Exercise the host-sparse int64-indptr / int64-indices kernel templates - (the int64-index/indptr overloads the suite otherwise never reaches). - These differ from the validated int32 host path - only in index dtype, so they must be bit-identical to it. Real int64 indices - only occur at nnz > 2^31 (unallocatable in CI), so we promote a small - matrix's index arrays explicitly and keep it host-resident (scipy sparse + - method='wilcoxon' is not moved to GPU).""" + """Host sparse int64 index templates must match the int32 path bit-for-bit.""" rng = np.random.default_rng(0) dense = (rng.random((150, 8)) * 4.0).astype(np.float64) dense[dense < 1.5] = 0.0 # nonnegative + structural zeros -> sparse fast path @@ -1689,13 +1631,7 @@ def test_host_sparse_int64_templates_match_int32(reference, layout, data_dtype): def _anndata_with_group_sizes(sizes, *, n_genes=6, seed=0): - """Dense AnnData whose per-group cell counts are exactly ``sizes``. - - The OVO tier dispatch picks the rank kernel by *test-group* size - (MEDIUM<=512, LARGE 513-2500, HUGE>2500), so engineered group sizes drive - specific bands. Integer data is float32-exact, so ranking is bit-identical - to scanpy. - """ + """Dense AnnData with exact per-group sizes for OVO tier tests.""" rng = np.random.default_rng(seed) labels = [] for name, n in sizes.items(): @@ -1735,22 +1671,14 @@ def _assert_ovo_matches_scanpy(adata, reference): ], ) def test_ovo_tier_bands_match_scanpy(sizes, seed): - """OVO dense-tiered path across MEDIUM/LARGE (groups <= 512 plus a 1000-cell - LARGE) and HUGE (a > 2500-cell group, CUB segmented sort); match scanpy.""" + """OVO dense-tiered MEDIUM/LARGE/HUGE paths must match scanpy.""" adata = _anndata_with_group_sizes(sizes, seed=seed) _assert_ovo_matches_scanpy(adata, reference="ref") @pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn def test_ovr_dense_gmem_branch_matches_scipy(): - """The DENSE OVR global-memory accumulator (use_gmem) engages only when the - per-block group accumulators exceed the 48 KB MaxSharedMemoryPerBlock limit - -- n_groups > 6112 (= 49152 / 8 - 32). No other test reaches it (the - >3056-group gmem tests only flip the *sparse* accumulator). n_groups=6200 - deterministically routes through dense gmem; a scanpy oracle here costs ~30s - (its per-group Python loop), so we validate a sample of groups against scipy - mannwhitneyu with rsc's exact settings (tie-corrected asymptotic, no - continuity).""" + """Dense OVR gmem branch must match scipy on sampled groups.""" from scipy.stats import mannwhitneyu n_groups, n_genes = 6200, 4 # > 6112 -> dense gmem accumulator @@ -1787,9 +1715,7 @@ def test_ovr_dense_gmem_branch_matches_scipy(): def test_skip_empty_groups_vs_rest_drops_singleton(): - """skip_empty_groups=True with reference='rest' silently drops <2-cell - groups (covers the reference=='rest' branch of _select_groups, which the - existing reference='ref' skip tests miss).""" + """skip_empty_groups=True with reference='rest' drops singleton groups.""" adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=4) rsc.tl.rank_genes_groups( adata, "group", method="wilcoxon", use_raw=False, skip_empty_groups=True @@ -1825,11 +1751,7 @@ def test_skip_empty_groups_none_remain_raises(): "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] ) def test_ovr_tie_correct_false_tie_heavy_matches_scanpy(fmt): - """OVR (reference='rest') tie_correct=False on TIE-HEAVY data must match - scanpy across every storage format. The pre-existing tie_correct=False oracle - uses tie-free blobs (tie_corr ~= 1), so a wrong uncorrected variance would - pass; integer data with many ties stresses the omitted tie term on each path - (dense, host-sparse CSR/CSC, device-sparse CSR/CSC).""" + """OVR tie_correct=False on tie-heavy data must match scanpy for all formats.""" rng = np.random.default_rng(7) n_obs, n_genes = 180, 8 dense = rng.integers(0, 5, size=(n_obs, n_genes)).astype(np.float64) # ties @@ -1863,15 +1785,7 @@ def test_ovr_tie_correct_false_tie_heavy_matches_scanpy(fmt): ) @pytest.mark.parametrize("reference", ["rest", "1"]) # OVR and OVO epilogues def test_use_continuity_matches_scipy(fmt, reference): - """use_continuity=True is validated only on dense-OVO elsewhere. Check the - continuity epilogue composes correctly with each path's rank_sums (OVR + OVO, - every format) vs scipy.mannwhitneyu(use_continuity=True, asymptotic). - - Groups OVERLAP (no separation) on purpose: that keeps |R-E[R]| moderate so - the 0.5 continuity term MATERIALLY changes p -- a missing continuity - correction would then fail the scipy oracle (non-vacuous). Because U and E[U] - are multiples of 0.5, rsc's clamp (max(|d|-0.5,0)) and scipy's shift agree - exactly. tie_correct=True matches scipy's always-on asymptotic tie term.""" + """Continuity epilogues must match scipy across OVR/OVO and formats.""" from scipy.stats import mannwhitneyu rng = np.random.default_rng(8) @@ -1912,9 +1826,7 @@ def test_use_continuity_matches_scipy(fmt, reference): ) -# --------------------------------------------------------------------------- -# Entry-point / init validation (rank_genes_groups + _RankGenes + _select_groups) -# --------------------------------------------------------------------------- +# Entry-point / init validation (rank_genes_groups + _RankGenes + _select_groups). def test_rank_genes_groups_default_method_is_ttest(): @@ -1987,9 +1899,7 @@ def test_singleton_group_without_skip_raises(): @pytest.mark.parametrize("use_raw", [None, True]) def test_rank_genes_groups_reads_raw_matches_scanpy(use_raw): - """use_raw=None (raw present) and use_raw=True both read adata.raw. X is - overwritten with rank-scrambling noise so a path that wrongly read .X would - diverge from scanpy (non-vacuous).""" + """use_raw=None and use_raw=True both read adata.raw, matching scanpy.""" adata = _anndata_with_group_sizes({"0": 30, "1": 30, "2": 30}, seed=6) adata.raw = adata.copy() # raw holds the real signal rng = np.random.default_rng(99) @@ -2039,14 +1949,11 @@ def test_log1p_base_logfoldchanges_match_scanpy(reference, fmt): ) -# --------------------------------------------------------------------------- -# OVO / OVR parity & dispatch gaps -# --------------------------------------------------------------------------- +# OVO / OVR parity and dispatch gaps. def test_ovo_dense_fallback_pts_match_scanpy(): - """OVO sparse-negative dense fallback computes pts via _fill_ovo_chunk_stats - (ref + group branches). Validate pts vs scanpy on the dense equivalent.""" + """OVO sparse-negative dense fallback pts must match scanpy.""" rng = np.random.default_rng(11) dense = (rng.random((120, 8)) * 5.0).astype(np.float64) dense[dense < 1.5] = 0.0 @@ -2078,9 +1985,7 @@ def test_ovo_dense_fallback_pts_match_scanpy(): @pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_csr"]) # CPU + GPU FDR epilogues def test_bonferroni_matches_scanpy(fmt): - """Bonferroni correction (CPU _core.py:584 via dense, GPU :630-631 via the - cupy OVO result path) must match scanpy, not just be <=1 (the prior tests - only asserted the tautological clamp).""" + """Bonferroni correction must match scanpy, not just clamp below one.""" rng = np.random.default_rng(12) dense = rng.integers(0, 5, size=(150, 6)).astype(np.float64) dense[dense < 1.0] = 0.0 diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index fc05af15..6922d7ce 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -429,17 +429,7 @@ def test_sparse_with_actual_zeros(self, adata_blobs): assert np.all(pvals <= 1) def test_sparse_negative_values_fallback(self, adata_blobs): - """Sparse input with negatives must densify: the sparse histogram puts - implicit zeros in bin 0 (valid only for nonnegative data). A *correct* - fallback (densify) matches the dense run; a removed fallback would bin - the implicit zeros below stored negatives and diverge -- so this - assertion fails without the fallback. - - Sensitivity hinges on columns holding BOTH structural zeros AND a value - below them (a negative). Where the zeros are the column minimum, moving - them to bin 0 leaves their rank order unchanged and the binned z is - invariant (which is why a naive sparse-vs-dense check is vacuous). - """ + """Sparse negatives must densify so implicit zeros rank correctly.""" import cupy as cp import cupyx.scipy.sparse as cpsp @@ -570,12 +560,7 @@ def test_top_genes_match_scipy(adata_blobs): @pytest.mark.parametrize("reference", ["rest", "1"]) def test_binned_bin_exact_matches_scipy(reference): - """wilcoxon_binned otherwise has NO external numeric oracle. With integer - data and n_bins >> value-range, each value gets its own bin -> binned ranks - == exact ranks -> binned pvals must match scipy.mannwhitneyu exactly - (tie_correct=True matches scipy's always-on asymptotic tie term). Covers - vs-rest and vs-ref, tie_correct and use_continuity, with non-vacuity - self-guards (each flag must materially change the result).""" + """Bin-exact integer data must match scipy.mannwhitneyu.""" import pandas as pd from scipy.stats import mannwhitneyu @@ -657,9 +642,7 @@ def test_binned_all_zero_sparse_finite(adata_blobs): def test_binned_log1p_invalid_for_negative_sparse_coerces_to_auto(adata_blobs): - """Sparse input with negatives + bin_range='log1p' warns and coerces to - 'auto' (the fixed [0,15] range would clamp negatives). Result must equal the - explicit 'auto' run (non-vacuous: no coercion -> mis-binned negatives differ).""" + """Negative sparse log1p range must warn, coerce to auto, and match auto.""" import cupy as cp import cupyx.scipy.sparse as cpsp