Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 308 additions & 9 deletions src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,23 @@ __global__ void ovo_rank_dense_vs_ref_kernel(
// LARGE/HUGE rank kernel; LARGE smem-sorts, HUGE reads CUB-sorted groups.
// Post-sort mid-rank/tie body is shared and each group owns its output row.
template <bool SMEM_SORT>
__global__ void ovo_rank_sorted_kernel(
const float* __restrict__ ref_sorted, const float* __restrict__ grp_in,
const int* __restrict__ grp_offsets,
const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums,
double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols,
int n_groups, bool compute_tie_corr, int large_padded, int skip_n_grp_le) {
__global__ void ovo_rank_sorted_kernel(const float* __restrict__ ref_sorted,
const float* __restrict__ grp_in,
const int* __restrict__ grp_offsets,
const double* __restrict__ ref_tie_sums,
double* __restrict__ rank_sums,
double* __restrict__ tie_corr, int n_ref,
int n_all_grp, int n_cols, int n_groups,
bool compute_tie_corr, int large_padded,
int skip_n_grp_le, int skip_n_grp_gt) {
int col = blockIdx.x;
int grp = blockIdx.y;
if (col >= n_cols || grp >= n_groups) return;

int g_start = grp_offsets[grp];
int n_grp = grp_offsets[grp + 1] - g_start;
if (n_grp <= skip_n_grp_le) return;
if (n_grp > skip_n_grp_gt) return;
if (n_grp == 0) {
if (threadIdx.x == 0) {
rank_sums[grp * n_cols + col] = 0.0;
Expand Down Expand Up @@ -214,6 +218,214 @@ __global__ void ovo_rank_sorted_kernel(
&tie_corr[grp * n_cols + col]);
}

// LARGE analytic-zero path for nonnegative sparse data.
// Sort only stored positives; zeros are handled from counts.
__global__ void ovo_rank_smem_analytic_kernel(
const float* __restrict__ ref_sorted, const float* __restrict__ grp_in,
const int* __restrict__ grp_offsets,
const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums,
double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols,
int n_groups, bool compute_tie_corr, int skip_n_grp_le, int skip_n_grp_gt) {
int col = blockIdx.x;
int grp = blockIdx.y;
if (col >= n_cols || grp >= n_groups) return;

int g_start = grp_offsets[grp];
int n_grp = grp_offsets[grp + 1] - g_start;
if (n_grp <= skip_n_grp_le) return;
if (n_grp > skip_n_grp_gt) return;
if (n_grp == 0) {
if (threadIdx.x == 0) {
rank_sums[grp * n_cols + col] = 0.0;
if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0;
}
return;
}

const float* ref_col = ref_sorted + (long long)col * n_ref;
const float* grp_col = grp_in + (long long)col * n_all_grp + g_start;

extern __shared__ float grp_smem[];
__shared__ double warp_buf[WARP_REDUCE_BUF];
__shared__ int sh_nnz;
__shared__ int sh_ref_zeros;
if (threadIdx.x == 0) {
sh_nnz = 0;
sh_ref_zeros = sorted_upper_bound(ref_col, 0, n_ref, 0.0f);
}
__syncthreads();

for (int i = threadIdx.x; i < n_grp; i += blockDim.x) {
float v = grp_col[i];
if (v > 0.0f) grp_smem[atomicAdd(&sh_nnz, 1)] = v;
}
__syncthreads();
int nnz = sh_nnz;
int ref_zeros = sh_ref_zeros;
int n_grp_zero = n_grp - nnz;
int total_zero = ref_zeros + n_grp_zero;

// Pad positives to a power of two for bitonic_sort_smem.
int padded = 1;
while (padded < nnz) padded <<= 1;
for (int i = nnz + threadIdx.x; i < padded; i += blockDim.x)
grp_smem[i] = __int_as_float(0x7f800000);
__syncthreads();
if (nnz > 1) bitonic_sort_smem(grp_smem, padded);
__syncthreads();

// Positive ranks are shifted by group zeros, which sort before positives.
double zero_rank =
(total_zero > 0) ? ((double)total_zero + 1.0) / 2.0 : 0.0;
double local_sum =
(threadIdx.x == 0) ? (double)n_grp_zero * zero_rank : 0.0;
int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0;
for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_smem, nnz, grp_smem[i],
ref_lb, ref_ub, grp_lb, grp_ub);
local_sum += r.mid_rank + (double)n_grp_zero;
}
double total = wilcoxon_block_sum(local_sum, warp_buf);
if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total;

if (!compute_tie_corr) return;
__syncthreads();

// Add nonzero tie deltas; zero ties are added below via T(ref+grp)-T(ref).
double local_tie = 0.0;
for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
if (i == 0 || grp_smem[i] != grp_smem[i - 1]) {
float v = grp_smem[i];
int gub = sorted_upper_bound(grp_smem, i + 1, nnz, v);
double cg = (double)(gub - i);
int rlo = sorted_lower_bound(ref_col, 0, n_ref, v);
int rub = sorted_upper_bound(ref_col, rlo, n_ref, v);
double cr = (double)(rub - rlo);
double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0;
local_tie += group_tie;
if (cr > 0.0) {
double comb = cr + cg;
double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0;
local_tie += comb * comb * comb - comb - ref_tie - group_tie;
}
}
}
double tie = wilcoxon_block_sum(local_tie, warp_buf);
if (threadIdx.x == 0) {
double zd = 0.0;
if (total_zero > 1)
zd += (double)total_zero * total_zero * total_zero - total_zero;
if (ref_zeros > 1)
zd -= (double)ref_zeros * ref_zeros * ref_zeros - ref_zeros;
tie_corr[grp * n_cols + col] =
finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie + zd);
}
}

// Compact HUGE-band positives and emit [base, base + nnz) segments.
__global__ void compact_huge_nonzeros_kernel(
const float* __restrict__ grp_dense, const int* __restrict__ grp_offsets,
const int* __restrict__ group_ids, float* __restrict__ grp_nz,
int* __restrict__ seg_begins, int* __restrict__ seg_ends, int n_all_grp,
int n_sort_groups, int sb_cols) {
int col = blockIdx.x;
int local = blockIdx.y;
if (col >= sb_cols || local >= n_sort_groups) return;
int g = group_ids[local];
int g_start = grp_offsets[g];
int n_grp = grp_offsets[g + 1] - g_start;
size_t base = (size_t)col * n_all_grp + g_start;
int f = col * n_sort_groups + local;

__shared__ int cnt;
if (threadIdx.x == 0) cnt = 0;
__syncthreads();
for (int i = threadIdx.x; i < n_grp; i += blockDim.x) {
float v = grp_dense[base + i];
if (v > 0.0f) grp_nz[base + atomicAdd(&cnt, 1)] = v;
}
__syncthreads();
if (threadIdx.x == 0) {
seg_begins[f] = (int)base;
seg_ends[f] = (int)base + cnt;
}
}

// Rank HUGE-band groups from sorted positives plus the zero block.
__global__ void ovo_rank_huge_analytic_kernel(
const float* __restrict__ ref_sorted,
const float* __restrict__ grp_nz_sorted,
const int* __restrict__ grp_offsets, const int* __restrict__ group_ids,
const int* __restrict__ seg_begins, const int* __restrict__ seg_ends,
const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums,
double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols,
int n_sort_groups, bool compute_tie_corr) {
int col = blockIdx.x;
int local = blockIdx.y;
if (col >= n_cols || local >= n_sort_groups) return;
int grp = group_ids[local];
int g_start = grp_offsets[grp];
int n_grp = grp_offsets[grp + 1] - g_start;
int f = col * n_sort_groups + local;
int b = seg_begins[f];
int nnz = seg_ends[f] - b;
const float* nz = grp_nz_sorted + b;
const float* ref_col = ref_sorted + (long long)col * n_ref;

__shared__ double warp_buf[WARP_REDUCE_BUF];
__shared__ int sh_ref_zeros;
if (threadIdx.x == 0)
sh_ref_zeros = sorted_upper_bound(ref_col, 0, n_ref, 0.0f);
__syncthreads();
int ref_zeros = sh_ref_zeros;
int n_grp_zero = n_grp - nnz;
int total_zero = ref_zeros + n_grp_zero;
double zero_rank =
(total_zero > 0) ? ((double)total_zero + 1.0) / 2.0 : 0.0;

double local_sum =
(threadIdx.x == 0) ? (double)n_grp_zero * zero_rank : 0.0;
int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0;
for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
OvoRank r = ovo_mid_rank(ref_col, n_ref, nz, nnz, nz[i], ref_lb, ref_ub,
grp_lb, grp_ub);
local_sum += r.mid_rank + (double)n_grp_zero;
}
double total = wilcoxon_block_sum(local_sum, warp_buf);
if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total;

if (!compute_tie_corr) return;
__syncthreads();
double local_tie = 0.0;
for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
if (i == 0 || nz[i] != nz[i - 1]) {
float v = nz[i];
int gub = sorted_upper_bound(nz, i + 1, nnz, v);
double cg = (double)(gub - i);
int rlo = sorted_lower_bound(ref_col, 0, n_ref, v);
int rub = sorted_upper_bound(ref_col, rlo, n_ref, v);
double cr = (double)(rub - rlo);
double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0;
local_tie += group_tie;
if (cr > 0.0) {
double comb = cr + cg;
double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0;
local_tie += comb * comb * comb - comb - ref_tie - group_tie;
}
}
}
double tie = wilcoxon_block_sum(local_tie, warp_buf);
if (threadIdx.x == 0) {
double zd = 0.0;
if (total_zero > 1)
zd += (double)total_zero * total_zero * total_zero - total_zero;
if (ref_zeros > 1)
zd -= (double)ref_zeros * ref_zeros * ref_zeros - ref_zeros;
tie_corr[grp * n_cols + col] =
finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie + zd);
}
}

// MEDIUM tie helper: sorted-reference contribution, one block per column.
// Rank kernels add only group-only/ref-overlap deltas.
__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted,
Expand Down Expand Up @@ -318,6 +530,93 @@ __global__ void ovo_rank_medium_kernel(
finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta);
}

// WARP/SMALL tiers were removed; MEDIUM now covers all groups <=
// OVO_MEDIUM_MAX. Restore notes live in
// .claude/wilcoxon-warp-small-tiers-removed.md.
// MEDIUM analytic-zero path for nonnegative sparse data.
__global__ void ovo_rank_medium_analytic_kernel(
const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense,
const int* __restrict__ grp_offsets,
const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums,
double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols,
int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) {
int col = blockIdx.x;
int grp = blockIdx.y;
if (col >= n_cols || grp >= n_groups) return;

int g_start = grp_offsets[grp];
int n_grp = grp_offsets[grp + 1] - g_start;
if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return;

extern __shared__ char smem_raw[];
float* grp_smem = (float*)smem_raw;
double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float));
__shared__ int sh_nnz;
__shared__ int sh_ref_zeros;

const float* ref_col = ref_sorted + (long long)col * n_ref;
const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start;
if (threadIdx.x == 0) {
sh_nnz = 0;
sh_ref_zeros = sorted_upper_bound(ref_col, 0, n_ref, 0.0f);
}
__syncthreads();
for (int i = threadIdx.x; i < n_grp; i += blockDim.x) {
float v = grp_col[i];
if (v > 0.0f) grp_smem[atomicAdd(&sh_nnz, 1)] = v;
}
__syncthreads();
int nnz = sh_nnz;
int ref_zeros = sh_ref_zeros;
int n_grp_zero = n_grp - nnz;
int total_zero = ref_zeros + n_grp_zero;
double zero_rank =
(total_zero > 0) ? ((double)total_zero + 1.0) / 2.0 : 0.0;

double local_sum =
(threadIdx.x == 0) ? (double)n_grp_zero * zero_rank : 0.0;
double local_tie = 0.0;
for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
float v = grp_smem[i];
int n_lt_ref = sorted_lower_bound(ref_col, 0, n_ref, v);
int n_eq_ref =
sorted_upper_bound(ref_col, n_lt_ref, n_ref, v) - n_lt_ref;
int n_lt_grp = 0;
int n_eq_grp = 0;
bool first_in_grp = true;
for (int j = 0; j < nnz; ++j) {
float w = grp_smem[j];
if (w < v) ++n_lt_grp;
if (w == v) {
++n_eq_grp;
if (j < i) first_in_grp = false;
}
}
local_sum += (double)(n_lt_ref + n_lt_grp + n_grp_zero) +
((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0;
if (compute_tie_corr && first_in_grp) {
double cg = (double)n_eq_grp;
double cr = (double)n_eq_ref;
double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0;
local_tie += group_tie;
if (cr > 0.0) {
double combined = cr + cg;
double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0;
local_tie += combined * combined * combined - combined -
ref_tie - group_tie;
}
}
}
double total = wilcoxon_block_sum(local_sum, warp_buf);
if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total;

if (!compute_tie_corr) return;
__syncthreads();
double tie = wilcoxon_block_sum(local_tie, warp_buf);
if (threadIdx.x == 0) {
double zd = 0.0;
if (total_zero > 1)
zd += (double)total_zero * total_zero * total_zero - total_zero;
if (ref_zeros > 1)
zd -= (double)ref_zeros * ref_zeros * ref_zeros - ref_zeros;
tie_corr[grp * n_cols + col] =
finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie + zd);
}
}
18 changes: 9 additions & 9 deletions src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,13 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref(
"dense OVO group offsets D2H");
auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups);
int max_grp_size = t1.max_grp_size;
bool run_large = t1.above_medium && t1.run_large;
bool run_huge = t1.above_medium && !run_large;
bool run_huge = compute_tie_corr && t1.run_huge;

std::vector<int> h_sort_group_ids;
int n_sort_groups = n_groups;
if (run_huge) {
h_sort_group_ids =
make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX);
make_sort_group_ids(h_offsets.data(), n_groups, t1.huge_skip_le);
n_sort_groups = (int)h_sort_group_ids.size();
}

Expand Down Expand Up @@ -507,7 +506,8 @@ static void launch_ovo_rank_dense_tiered_unsorted_ref(
ovo_dispatch_tiers(ref_sub, grp_sub, grp_offsets, t1, sc,
d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes,
sb_grp_items_actual, tpb_rank, n_ref, n_all_grp,
sb_cols, n_groups, compute_tie_corr, stream);
sb_cols, n_groups, compute_tie_corr,
/*analytic_zeros=*/false, stream);

cuda_check(
cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double),
Expand Down Expand Up @@ -552,14 +552,13 @@ static void launch_ovo_rank_dense_host_streaming(

auto tier_plan = make_ovo_tier_plan(h_grp_offsets, n_groups);
int max_grp_size = tier_plan.max_grp_size;
bool run_large = tier_plan.above_medium && tier_plan.run_large;
bool run_huge = tier_plan.above_medium && !run_large;
bool run_huge = compute_tie_corr && tier_plan.run_huge;

std::vector<int> h_sort_group_ids;
int n_sort_groups = n_groups;
if (run_huge) {
h_sort_group_ids =
make_sort_group_ids(h_grp_offsets, n_groups, OVO_MEDIUM_MAX);
h_sort_group_ids = make_sort_group_ids(h_grp_offsets, n_groups,
tier_plan.huge_skip_le);
n_sort_groups = (int)h_sort_group_ids.size();
}

Expand Down Expand Up @@ -794,7 +793,8 @@ static void launch_ovo_rank_dense_host_streaming(
ovo_dispatch_tiers(ref_sub, grp_sub, d_grp_offsets, tier_plan, sc,
d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes,
sb_grp_items_actual, tpb_rank, n_ref, n_all_grp,
sb_cols, n_groups, compute_tie_corr, stream);
sb_cols, n_groups, compute_tie_corr,
/*analytic_zeros=*/false, stream);

cuda_check(
cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double),
Expand Down
Loading