diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 3f24105b..4f824ec6 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,6 +1,5 @@ -# This workflow will build two Docker image and push then to GitHub Packages Container registry: -# - a base image with the dependencies -# - a main image with the application code +# Build/push two GHCR images: dependency base and application image. +# Release events push; PR/comment runs only validate. name: Docker @@ -73,8 +72,8 @@ jobs: RAPIDS_VER: - "26.04" CUDA_SUFFIX: - - { ver: "12.8.0", label: "cuda12", pkg: "cu12" } - - { ver: "13.0.2", label: "cuda13", pkg: "cu13" } + - { ver: "12.9.1", label: "cuda12", pkg: "cu12" } + - { ver: "13.1.0", label: "cuda13", pkg: "cu13" } name: Build Docker images (${{ matrix.CUDA_SUFFIX.label }}) runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e3a5fde0..9b086988 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -69,16 +69,46 @@ jobs: path = pathlib.Path("pyproject.toml") text = path.read_text() + def remove_toml_array(text, key): + lines = text.splitlines(keepends=True) + out = [] + i = 0 + while i < len(lines): + if lines[i].startswith(f"{key} = ["): + depth = lines[i].count("[") - lines[i].count("]") + i += 1 + while i < len(lines) and depth > 0: + depth += lines[i].count("[") - lines[i].count("]") + i += 1 + continue + out.append(lines[i]) + i += 1 + return "".join(out) + # Rename package text = text.replace( 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) # Rename matching extra to "rapids", remove the other - text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + text = text.replace(f'rapids-cu{cuda} = [', 'rapids = [') + text = remove_toml_array(text, f"rapids-cu{other}") + + # CMake links CUDA extensions against librmm. + # Add the matching wheel to isolated build requirements. + for dep in ( + f' "librmm-cu{other}>=25.12",\n', + f' "rmm-cu{other}>=25.12",\n', + ): + text = text.replace(dep, "") + rmm_build_req = f' "librmm-cu{cuda}>=25.12",\n' + build_system_text = text.split("[project]", 1)[0] + if f'"librmm-cu{cuda}>=25.12"' not in build_system_text: + text = text.replace( + ']\nbuild-backend = "scikit_build_core.build"', + f'{rmm_build_req}]\nbuild-backend = "scikit_build_core.build"', + 1, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( @@ -96,6 +126,7 @@ jobs: - name: Sanity check pyproject.toml run: | + python3 -c "import tomllib; tomllib.load(open('pyproject.toml', 'rb'))" grep -E "name|rapids|CUDA_ARCH" pyproject.toml - name: Build CUDA manylinux image @@ -116,18 +147,25 @@ jobs: LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > + rm -f build/.librmm_dir && + mkdir -p build && python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} && + RMM_ROOT=$(python -c "import librmm; print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger; print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig && + python -c "import librmm; print(librmm.__path__[0])" > build/.librmm_dir && + echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - # Exclude CUDA libs by SONAME glob (auditwheel >=6.2): the runtime - # stack (CuPy / nvidia-* wheels) provides them. Globs are version - # agnostic -- cusolver's SONAME is libcusolver.so.11 on CUDA 12 but - # .12 on CUDA 13, and nvJitLink is .12 vs .13, so pinning to the CUDA - # major would graft the wrong (or no) lib. cusolver's transitive deps - # (cublasLt, cusparse ~186MB, nvJitLink) are reached by auditwheel's - # tree walk and must each be excluded or they bloat the wheel. - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' -w {dest_dir} {wheel}" + # Exclude CUDA/RAPIDS runtime libs provided by dependency wheels. + # Use SONAME globs so CUDA 12/13 suffix changes do not bundle them. + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v7 diff --git a/.gitignore b/.gitignore index 749b2e8d..8158195d 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,4 @@ AGENTS.md # tmp_scripts tmp_scripts/ +/benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index f0d966c0..63fb3f01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,130 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + set(RSC_RMM_HINTS) + set(RSC_RAPIDS_CMAKE_PREFIXES) + set(RSC_CCCL_HINTS) + set(RSC_RAPIDS_LOGGER_HINTS) + set(RSC_NVTX3_HINTS) + macro(_rsc_collect_rapids_python_prefix _rsc_prefix) + if (NOT "${_rsc_prefix}" STREQUAL "") + file(GLOB _rsc_rmm_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/rmm") + file(GLOB _rsc_rapids_prefixes + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64" + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids" + "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib" + ) + file(GLOB _rsc_cccl_dirs + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids/cmake/cccl" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib/cmake/cccl" + ) + file(GLOB _rsc_rapids_logger_dirs "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_nvtx3_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_nvtx3_dirs}) + endif() + endmacro() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import importlib.util, pathlib; spec = importlib.util.find_spec('librmm'); print(pathlib.Path(spec.origin).parent / 'lib64' / 'cmake' / 'rmm' if spec else '')" + OUTPUT_VARIABLE RSC_PYTHON_RMM_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") + set(_rsc_python_rmm_hint "${RSC_PYTHON_RMM_DIR}") + else() + set(_rsc_python_rmm_hint "") + endif() + # Wheel builds write build/.librmm_dir from CIBW_BEFORE_BUILD. + # publish.yml symlinks runtime libs so auditwheel excludes them. + if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake") + set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}") + elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir") + file(READ "${CMAKE_SOURCE_DIR}/build/.librmm_dir" _rsc_librmm_marker) + string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + else() + set(_rsc_librmm_marker "") + endif() + if(NOT "${_rsc_librmm_marker}" STREQUAL "" AND EXISTS "${_rsc_librmm_marker}/lib64/cmake/rmm/rmm-config.cmake") + file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") + file(GLOB _rsc_marker_rapids_prefixes + "${_rsc_librmm_marker}/lib64" + "${_rsc_librmm_marker}/lib64/rapids" + "${_rsc_librmm_marker}/../rapids_logger/lib64" + ) + file(GLOB _rsc_marker_cccl_dirs + "${_rsc_librmm_marker}/lib64/rapids/cmake/cccl" + ) + file(GLOB _rsc_marker_rapids_logger_dirs "${_rsc_librmm_marker}/../rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_marker_nvtx3_dirs "${_rsc_librmm_marker}/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_marker_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_marker_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_marker_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_marker_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_marker_nvtx3_dirs}) + endif() + foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") + _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") + endforeach() + foreach(_rsc_env_prefix IN ITEMS "$ENV{CONDA_PREFIX}" "$ENV{VIRTUAL_ENV}") + _rsc_collect_rapids_python_prefix("${_rsc_env_prefix}") + endforeach() + string(REPLACE ":" ";" _rsc_path_entries "$ENV{PATH}") + foreach(_rsc_path_entry IN LISTS _rsc_path_entries) + get_filename_component(_rsc_path_prefix "${_rsc_path_entry}/.." ABSOLUTE) + _rsc_collect_rapids_python_prefix("${_rsc_path_prefix}") + endforeach() + if (NOT RSC_RMM_HINTS + AND NOT "${_rsc_python_rmm_hint}" STREQUAL "") + list(APPEND RSC_RMM_HINTS "${_rsc_python_rmm_hint}") + endif() + if (RSC_RAPIDS_CMAKE_PREFIXES) + list(APPEND CMAKE_PREFIX_PATH ${RSC_RAPIDS_CMAKE_PREFIXES}) + if (RSC_CCCL_HINTS) + list(GET RSC_CCCL_HINTS 0 _rsc_cccl_dir) + set(CCCL_DIR "${_rsc_cccl_dir}" CACHE PATH "Path to CCCL package config" FORCE) + endif() + if (RSC_RAPIDS_LOGGER_HINTS) + list(GET RSC_RAPIDS_LOGGER_HINTS 0 _rsc_rapids_logger_dir) + set(rapids_logger_DIR "${_rsc_rapids_logger_dir}" CACHE PATH "Path to rapids_logger package config" FORCE) + endif() + if (RSC_NVTX3_HINTS) + list(GET RSC_NVTX3_HINTS 0 _rsc_nvtx3_dir) + set(nvtx3_DIR "${_rsc_nvtx3_dir}" CACHE PATH "Path to nvtx3 package config" FORCE) + endif() + endif() + if (RSC_RMM_HINTS) + list(GET RSC_RMM_HINTS 0 _rsc_rmm_dir) + set(rmm_DIR "${_rsc_rmm_dir}" CACHE PATH "Path to rmm package config" FORCE) + find_package(rmm CONFIG REQUIRED) + else() + find_package(rmm CONFIG REQUIRED) + endif() + + # CCCL 3.3.0 gates cudaDevAttrHostNumaMemoryPoolsSupported too loosely. + # Fail fast for CUDA 12.6-12.8 source builds with that buggy CCCL. + set(_rsc_cccl_buggy_numa_guard TRUE) + if (DEFINED CCCL_VERSION AND CCCL_VERSION VERSION_GREATER 3.3.0) + set(_rsc_cccl_buggy_numa_guard FALSE) + endif() + if (NOT RSC_SKIP_CUDA_VERSION_CHECK + AND _rsc_cccl_buggy_numa_guard + AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.6 + AND CUDAToolkit_VERSION VERSION_LESS 12.9) + message(FATAL_ERROR + "Cannot build rapids_singlecell from source with CUDA ${CUDAToolkit_VERSION} against " + "CCCL ${CCCL_VERSION} (RAPIDS 26.04): it references cudaDevAttrHostNumaMemoryPoolsSupported, " + "which the CUDA 12.6-12.8 toolkit does not define (NVIDIA added it in 12.9). " + "Use CUDA >= 12.9 (or <= 12.5), upgrade to RAPIDS >= 26.06 (CCCL > 3.3.0 fixes the guard), " + "or install the prebuilt wheel (pip install rapids-singlecell-cu12). " + "If your toolkit does define this enum, override with -DRSC_SKIP_CUDA_VERSION_CHECK=ON.") + endif() + + message(STATUS "Using RMM for CUDA extension scratch allocations") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -62,6 +186,57 @@ function(add_nb_cuda_module target src) endif() endfunction() +# RMM-backed nanobind CUDA module: normal module plus shared scratch allocator. +# Wheels use sibling RAPIDS packages; editable imports still preload fallbacks. +function(add_rmm_cuda_module target src) + add_nb_cuda_module(${target} ${src}) + if (RSC_BUILD_EXTENSIONS) + target_sources(${target} PRIVATE + src/rapids_singlecell/_cuda/rmm_scratch.cu) + target_link_libraries(${target} PRIVATE rmm::rmm) + set(_rsc_rmm_build_rpath) + set(_rsc_rmm_have_build_librmm FALSE) + set(_rsc_rmm_have_build_rapids_logger FALSE) + if (DEFINED ENV{CONDA_PREFIX}) + set(_rsc_rmm_env_site + "$ENV{CONDA_PREFIX}/lib/python${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}/site-packages") + if (EXISTS "${_rsc_rmm_env_site}/librmm/lib64") + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_env_site}/librmm/lib64") + set(_rsc_rmm_have_build_librmm TRUE) + endif() + if (EXISTS "${_rsc_rmm_env_site}/rapids_logger/lib64") + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_env_site}/rapids_logger/lib64") + set(_rsc_rmm_have_build_rapids_logger TRUE) + endif() + endif() + if (NOT _rsc_rmm_have_build_librmm AND rmm_DIR) + get_filename_component(_rsc_rmm_build_librmm_dir + "${rmm_DIR}/../.." REALPATH) + list(APPEND _rsc_rmm_build_rpath "${_rsc_rmm_build_librmm_dir}") + endif() + if (NOT _rsc_rmm_have_build_rapids_logger AND rapids_logger_DIR) + get_filename_component(_rsc_rmm_build_rapids_logger_dir + "${rapids_logger_DIR}/../.." REALPATH) + list(APPEND _rsc_rmm_build_rpath + "${_rsc_rmm_build_rapids_logger_dir}") + endif() + set(_rsc_rmm_install_rpath + "\$ORIGIN/../../librmm/lib64" + "\$ORIGIN/../../rapids_logger/lib64" + ) + if (CUDAToolkit_LIBRARY_DIR) + list(APPEND _rsc_rmm_build_rpath "${CUDAToolkit_LIBRARY_DIR}") + list(APPEND _rsc_rmm_install_rpath "${CUDAToolkit_LIBRARY_DIR}") + endif() + set_target_properties(${target} PROPERTIES + BUILD_RPATH "${_rsc_rmm_build_rpath}" + INSTALL_RPATH "${_rsc_rmm_install_rpath}" + ) + endif() +endfunction() + if (RSC_BUILD_EXTENSIONS) # CUDA modules add_nb_cuda_module(_mean_var_cuda src/rapids_singlecell/_cuda/mean_var/mean_var.cu) @@ -91,7 +266,9 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_pseudobulk_cuda src/rapids_singlecell/_cuda/pseudobulk/pseudobulk.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) - add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_rmm_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_rmm_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) + add_nb_cuda_module(_rank_stats_cuda src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/conda/rsc_rapids_26.04_cuda12.yml b/conda/rsc_rapids_26.04_cuda12.yml index 537b365a..f0010a8b 100644 --- a/conda/rsc_rapids_26.04_cuda12.yml +++ b/conda/rsc_rapids_26.04_cuda12.yml @@ -7,7 +7,7 @@ channels: dependencies: - rapids=26.04 - python=3.14 - - cuda-version=12.8 + - cuda-version=12.9 - cudnn - cutensor - cusparselt diff --git a/docker/Dockerfile b/docker/Dockerfile index cc533e46..344811a3 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,6 +5,11 @@ ARG GIT_ID=main SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ENV PATH=/opt/conda/bin:$PATH +# Point CMake's find_package(rmm) at the conda env. The conda RAPIDS env resolved +# librmm + cuda-version together, so its librmm/rapids_logger headers match the +# image's CUDA toolkit. This is what lets the --no-build-isolation build below +# pick up the CUDA-matched librmm instead of a mismatched PyPI wheel. +ENV CMAKE_PREFIX_PATH=/opt/conda ARG CUDA_ARCHS="75-real;80-real;86-real;89-real;90-real;100-real;120" RUN < "cudaDevAttr* has no global scope" +# errors on both cu12 (toolkit older than the latest librmm) and cu13 (wrong +# cu12 variant). Install the PEP 517 backend deps first since isolation is off; +# the conda env already provides the librmm/rapids_logger headers + cmake config. +/opt/conda/bin/python -m pip install --no-cache-dir scikit-build-core nanobind setuptools-scm cmake ninja +/opt/conda/bin/python -m pip install --no-cache-dir --no-build-isolation -e . EOF diff --git a/docker/Dockerfile.deps b/docker/Dockerfile.deps index 6638a67d..aad3b0a5 100644 --- a/docker/Dockerfile.deps +++ b/docker/Dockerfile.deps @@ -1,4 +1,4 @@ -ARG CUDA_VER=13.0.2 +ARG CUDA_VER=13.1.0 ARG LINUX_VER=ubuntu24.04 FROM nvidia/cuda:${CUDA_VER}-devel-${LINUX_VER} @@ -7,7 +7,7 @@ SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ARG PYTHON_VER=3.13 # Re-declare after FROM so it is available to RUN steps (passed by docker.yml build-args) -ARG CUDA_VER=13.0.2 +ARG CUDA_VER=13.1.0 ENV PATH=/opt/conda/bin:$PATH ENV PYTHON_VERSION=${PYTHON_VER} diff --git a/docker/docker-push.sh b/docker/docker-push.sh index 69801f79..4a137fa7 100755 --- a/docker/docker-push.sh +++ b/docker/docker-push.sh @@ -6,7 +6,7 @@ rapids_version=26.04 declare -A cuda_versions=( [cu12]="12.8.0" - [cu13]="13.0.2" + [cu13]="13.1.0" ) declare -A cuda_archs=( diff --git a/docs/contributing.md b/docs/contributing.md index f0542e57..e68011d2 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -7,6 +7,24 @@ - NVIDIA GPU with CUDA support - [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html), conda/mamba, or [uv](https://docs.astral.sh/uv/) - A RAPIDS environment (e.g., conda `rapids-26.04` or pip-installed RAPIDS) +- **CUDA toolkit ≥ 12.9, or ≤ 12.5, for building from source** (see note below) + +```{important} +**On RAPIDS 26.04, building from source needs CUDA ≥ 12.9 (or ≤ 12.5) on CUDA 12.** +RAPIDS 26.04 ships CCCL 3.3.0, which references the `cudaDevAttrHostNumaMemoryPoolsSupported` +device attribute whenever the toolkit is ≥ 12.6, but NVIDIA only added that enum in +CUDA 12.9. So compiling the RMM/CCCL-using kernels (the Wilcoxon scratch allocator) +against a **CUDA 12.6–12.8** toolkit fails with +`error: the global scope has no "cudaDevAttrHostNumaMemoryPoolsSupported"`. + +This is an upstream CCCL guard bug, **fixed in CCCL > 3.3.0 (RAPIDS ≥ 26.06)** — so +the gap only affects RAPIDS 26.04. CUDA 13.x is unaffected. If you're on RAPIDS 26.04 ++ CUDA 12.6–12.8, either build with CUDA ≥ 12.9 (or ≤ 12.5), upgrade to RAPIDS ≥ 26.06, +or just use the **prebuilt wheel** (`pip install rapids-singlecell-cu12`) — wheels are +built on CUDA 12.2 (below the guard), so the enum is never referenced and they run fine +on any CUDA 12.x runtime, including 12.6–12.8. The build emits an actionable error in +this range; override only if your toolkit defines the enum with `-DRSC_SKIP_CUDA_VERSION_CHECK=ON`. +``` ### Clone and install diff --git a/docs/installation.md b/docs/installation.md index 9dd5deb3..a21f9dcd 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -65,6 +65,17 @@ pip install rapids-singlecell-cu12 This installs the precompiled CUDA kernels but **not** the RAPIDS stack (cupy, cuml, cudf, etc.). This is the recommended approach for **conda/mamba users** who already have RAPIDS installed in their environment. +```{note} +The RAPIDS stack is **required**, not optional: `rapids_singlecell` imports +`cuml`/`cupy` at the top of its package `__init__`, and the compiled kernels +(Wilcoxon, GMM, …) link `librmm` / `rapids_logger` at runtime. These are +provided by an existing RAPIDS conda/mamba environment or by the +`[rapids]`/`[rapids-cuXX]` extra below. Installing the bare +`rapids-singlecell-cuXX` wheel into an environment without RAPIDS raises an +`ImportError` on `import rapids_singlecell` itself — not merely when a kernel is +first used. +``` + ### Prebuilt wheels with RAPIDS dependencies To also install the RAPIDS stack via pip, use the `rapids` extra. @@ -102,6 +113,13 @@ pip install 'rapids-singlecell[rapids-cu12]' --extra-index-url=https://pypi.nvid ```{note} Building from source requires the CUDA toolkit (nvcc) and CMake >= 3.24 to be available in your environment. The nvcc/CUDAToolkit found during the build should match the RAPIDS/CuPy CUDA major runtime version in or linked to the environment. + +Isolated source builds (the default for `pip install rapids-singlecell` and the +`git+` installs below) pull `librmm-cu12` into the build environment regardless +of your local CUDA major. On a **CUDA 13** system this mismatches the toolkit, so +build inside an environment that already provides a matching `librmm` and pass +`--no-build-isolation` (e.g. `pip install --no-build-isolation "rapids-singlecell @ git+…"`) +so the build uses the environment's `librmm` instead of the cu12 wheel. ``` ### Install from GitHub diff --git a/docs/release-notes/0.15.3.md b/docs/release-notes/0.16.0.md similarity index 61% rename from docs/release-notes/0.15.3.md rename to docs/release-notes/0.16.0.md index 6b3a3f8c..fdf715e4 100644 --- a/docs/release-notes/0.15.3.md +++ b/docs/release-notes/0.16.0.md @@ -1,7 +1,9 @@ -### 0.15.3 {small}`the-future` +### 0.16.0 {small}`the-future` ```{rubric} Features ``` +* Reworked GPU {func}`~rapids_singlecell.tl.rank_genes_groups` Wilcoxon onto dedicated nanobind CUDA kernels {pr}`636` {smaller}`S Dicks` +* {func}`~rapids_singlecell.tl.rank_genes_groups` no longer truncates gene names longer than 50 characters in ``uns[...]['names']`` (the field is now ``object`` dtype, matching Scanpy) {pr}`636` {smaller}`S Dicks` * Add {class}`~rapids_singlecell.ptg.Mixscape` for GPU-accelerated Mixscape (`perturbation_signature`, `mixscape`, `mixscale`, `lda`) {pr}`688` {smaller}`S Dicks` ```{rubric} Performance diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 329eb0ed..1f01cc8a 100644 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -3,9 +3,11 @@ # Release notes -## Version 0.15.0 -```{include} /release-notes/0.15.3.md +## Version 0.16.0 +```{include} /release-notes/0.16.0.md ``` + +## Version 0.15.0 ```{include} /release-notes/0.15.2.md ``` ```{include} /release-notes/0.15.1.md diff --git a/pyproject.toml b/pyproject.toml index 33930ab8..56611f61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,9 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # Wilcoxon links librmm at build time; generic isolated builds use CUDA 12. + # 25.12+ provides the resource-ref API and flat RMM header path we use. + "librmm-cu12>=25.12", ] build-backend = "scikit_build_core.build" @@ -32,8 +35,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "cudf-cu13>=25.12", + "cuml-cu13>=25.12", + "cugraph-cu13>=25.12", + "cuvs-cu13>=25.12", + "librmm-cu13>=25.12", +] +rapids-cu12 = [ + "cupy-cuda12x", + "cudf-cu12>=25.12", + "cuml-cu12>=25.12", + "cugraph-cu12>=25.12", + "cuvs-cu12>=25.12", + "librmm-cu12>=25.12", +] doc = [ "sphinx>=4.5.0", @@ -150,8 +167,10 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] +# Exclude CUDA/RAPIDS runtime libs provided by dependency wheels. +# Keep in sync with CIBW_REPAIR_WHEEL_COMMAND in publish.yml. repair-wheel-command = [ - "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 -w {dest_dir} {wheel}", + "auditwheel repair --exclude 'libcublas.so.*' --exclude 'libcublasLt.so.*' --exclude 'libcudart.so.*' --exclude 'libcusolver.so.*' --exclude 'libcusparse.so.*' --exclude 'libnvJitLink.so.*' --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", ] [tool.cibuildwheel.macos] diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index b897c42d..d4f70d12 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -5,8 +5,10 @@ operations. Each module is compiled from CUDA source files and exposed through nanobind bindings. -On systems without compiled extensions (e.g., docs builds), imports resolve -to None so that module-level imports don't raise ImportError. +On systems without compiled extensions (e.g., docs builds), a genuinely absent +module resolves to None so that module-level imports don't raise ImportError. A +module that is present but fails to load (ABI/toolkit mismatch, missing shared +library) is re-raised with context rather than silently swallowed. """ from __future__ import annotations @@ -43,18 +45,42 @@ "_pv_cuda", "_qc_cuda", "_qc_dask_cuda", + "_rank_stats_cuda", "_scale_cuda", "_sparse2dense_cuda", "_spca_cuda", "_wilcoxon_binned_cuda", "_wilcoxon_cuda", + "_wilcoxon_sparse_cuda", ] +def _preload_rapids_runtime_libs() -> None: + """Pre-load RAPIDS runtime libs so extension ``DT_NEEDED`` deps resolve.""" + for mod in ("librmm", "rapids_logger"): + try: + importlib.import_module(mod).load_library() + except (ImportError, OSError, AttributeError, RuntimeError): + pass + + +_preload_rapids_runtime_libs() + + def __getattr__(name: str): if name in __all__: try: return importlib.import_module(f".{name}", __name__) - except ImportError: + except ModuleNotFoundError: + # Extension genuinely absent (docs/no-GPU): degrade to None. return None + except ImportError as exc: + # Present but failed to load: surface ABI/toolkit/lib errors now. + # Returning None would cause a later cryptic attribute error. + msg = ( + f"Failed to load compiled CUDA extension {name!r}: {exc}. " + "Ensure a matching rapids-singlecell-cuXX wheel (and librmm) is " + "installed for your CUDA version." + ) + raise ImportError(msg) from exc raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index b2220ce8..dc27d4f1 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -8,9 +8,8 @@ namespace nb = nanobind; -/// Check the last CUDA error after a kernel launch. -/// Call immediately after every <<<...>>> launch to catch configuration errors -/// (invalid grid/block, shared memory overflow, etc.) before they propagate. +/// Check cudaGetLastError after a <<<...>>> launch (invalid grid/block, +/// shared memory overflow, etc.). inline void cuda_check_last_error(const char* kernel_name) { cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { @@ -21,13 +20,26 @@ inline void cuda_check_last_error(const char* kernel_name) { #define CUDA_CHECK_LAST_ERROR(kernel_name) cuda_check_last_error(#kernel_name) -/// Per-axis cached cap on `gridDim.{x,y,z}`. These differ in CUDA: -/// gridDim.x: 2^31-1 on CC 3.0+ -/// gridDim.y: 65535 on most GPUs -/// gridDim.z: 65535 -/// Newer hardware may relax these; we read at runtime and cache per device. -/// Returns a 3-element array indexed by 0=x, 1=y, 2=z. Multi-GPU safe via -/// thread-local cache keyed on the active device. +/// Check a cudaError_t returned directly by a CUDA/CUB API call. +/// Failed calls surface with a clear label instead of corrupted output later. +inline void cuda_check(cudaError_t err, const char* what) { + if (err != cudaSuccess) { + throw std::runtime_error(std::string(what) + + " failed: " + cudaGetErrorString(err)); + } +} + +/// Validate a binding-argument precondition (array dims vs. scalar shapes). +/// Mismatches become clean Python errors, not out-of-bounds launches. +inline void nb_require(bool cond, const char* what) { + if (!cond) { + throw std::invalid_argument( + std::string("rank_genes_groups CUDA binding: ") + what); + } +} + +/// Per-axis cached cap on `gridDim.{x,y,z}`; y/z are often only 65535. +/// Runtime per-device cache keeps this multi-GPU safe. inline const int* max_grid_dims() { static thread_local int cached_dev = -1; static thread_local int cached[3] = {65535, 65535, 65535}; // safe fallback @@ -54,15 +66,8 @@ inline int max_grid_dim_z() { return max_grid_dims()[2]; } -/// Grid-stride cap for kernels whose total work `nwork` (e.g. nnz, n_cells * -/// n_genes) may exceed what a single grid launch can cover. Pair with a -/// grid-strided loop inside the kernel: -/// -/// const long long stride = (long long)blockDim.x * gridDim.x; -/// for (long long i = ...; i < nwork; i += stride) { ... } -/// -/// Defaults to the `gridDim.x` cap. For 2D launches whose strided axis is y, -/// use `strided_grid_y`. Returns at least 1. +/// Grid-stride cap for kernels whose total work exceeds one grid launch. +/// Pair with a grid-strided loop; use `strided_grid_y` for y-axis launches. inline unsigned int strided_grid(long long nwork, int block_size) { const long long max_grid = max_grid_dim_x(); long long ideal = (nwork + block_size - 1) / block_size; @@ -70,8 +75,7 @@ inline unsigned int strided_grid(long long nwork, int block_size) { return (unsigned int)(capped < 1 ? 1 : capped); } -/// Like `strided_grid` but for the y-axis of a 2D/3D grid (much lower cap, -/// typically 65535). Use when the y dimension is the one being strided over. +/// Like `strided_grid` but for the y-axis (much lower cap, typically 65535). inline unsigned int strided_grid_y(long long nwork, int block_size) { const long long max_grid = max_grid_dim_y(); long long ideal = (nwork + block_size - 1) / block_size; @@ -80,9 +84,7 @@ inline unsigned int strided_grid_y(long long nwork, int block_size) { } // GPU array aliases for nanobind bindings, parameterized on device type. -// Bindings are registered for both nb::device::cuda (kDLCUDA = 2) and -// nb::device::cuda_managed (kDLCUDAManaged = 13) so that RMM managed-memory -// allocations are accepted without losing type safety for CPU arrays. +// CUDA and managed-memory variants both preserve CPU/GPU type safety. // C-contiguous (row-major) template @@ -92,19 +94,24 @@ using gpu_array_c = nb::ndarray; template using gpu_array_f = nb::ndarray; -// No contiguity constraint (accepts any order) +// No contiguity constraint template using gpu_array = nb::ndarray; -// Parameterized contiguity (for kernels that handle both C and F order) +// Parameterized contiguity (kernels handling both C and F order) template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; +template +using host_array_c2 = nb::ndarray, nb::c_contig>; +template +using host_array_f2 = nb::ndarray, nb::f_contig>; + // Register bindings for both regular CUDA and managed-memory arrays. -// Usage: -// template -// void register_bindings(nb::module_& m) { ... } -// NB_MODULE(_foo_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); } +// Each registration function must be templated on `Device`. #define REGISTER_GPU_BINDINGS(func, module) \ func(module); \ func(module) diff --git a/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu new file mode 100644 index 00000000..6064e1bf --- /dev/null +++ b/src/rapids_singlecell/_cuda/rank_genes/rank_stats.cu @@ -0,0 +1,235 @@ +#include + +#include "../nb_types.h" +#include "../sparse_extract/sparse_extract.cuh" + +using namespace nb::literals; + +namespace { + +constexpr int GROUP_STATS_BLOCK = 256; + +// Benjamini-Hochberg tail: reverse cumulative min on sorted, BH-scaled rows. +// NaNs become 1.0; one serial thread per row. +__global__ void fdr_bh_reverse_cummin_kernel(double* values, const int n_cols) { + const int row = blockIdx.x; + double running = 1.0; + double* row_values = values + static_cast(row) * n_cols; + for (int col = n_cols - 1; col >= 0; --col) { + double value = row_values[col]; + if (!(value == value)) { // NaN -> 1.0 + value = 1.0; + } + if (value < running) { + running = value; + } + row_values[col] = running; + } +} + +// Per-group sum/sumsq/nnz over a dense F-order block; invalid groups are +// skipped. Outputs are C-order group x col and grid-strided beyond gridDim.x. +__global__ void group_chunk_stats_kernel( + const double* block, const int* group_codes, double* group_sums, + double* group_sum_sq, double* group_nnz, const int n_rows, const int n_cols, + const int n_groups, const bool compute_nnz) { + const long long total = static_cast(n_rows) * n_cols; + const long long stride = static_cast(blockDim.x) * gridDim.x; + for (long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + idx < total; idx += stride) { + const int row = idx % n_rows; + const int col = idx / n_rows; + const int group = group_codes[row]; + if (group < 0 || group >= n_groups) { + continue; + } + const double value = block[idx]; + const long long out = static_cast(group) * n_cols + col; + atomicAdd(group_sums + out, value); + atomicAdd(group_sum_sq + out, value * value); + if (compute_nnz && value != 0.0) { + atomicAdd(group_nnz + out, 1.0); + } + } +} + +} // namespace + +// CSR -> dense F-order (double) window densify, in a single fused pass. +template +static void def_csr_tile_to_dense(nb::module_& m) { + m.def( + "csr_tile_to_dense", + [](gpu_array_c indptr, + gpu_array_c indices, + gpu_array_c data, + gpu_array_f out, int col_lb, int col_ub, + std::uintptr_t stream) { + const int n_cells = static_cast(indptr.shape(0)) - 1; + if (n_cells <= 0 || col_ub <= col_lb) { + return; + } + if (col_lb < 0) { + throw std::invalid_argument( + "csr_tile_to_dense: col_lb must be non-negative"); + } + if (indices.shape(0) != data.shape(0)) { + throw std::invalid_argument( + "csr_tile_to_dense: indices and data must have equal " + "length"); + } + if (out.ndim() != 2 || static_cast(out.shape(0)) != n_cells || + static_cast(out.shape(1)) < + static_cast(col_ub) - col_lb) { + throw std::invalid_argument( + "csr_tile_to_dense: out must be a (n_cells, >= col_ub - " + "col_lb) array"); + } + constexpr int CSR_TILE_BLOCK = 128; + const unsigned int grid = + (static_cast(n_cells) + CSR_TILE_BLOCK - 1) / + CSR_TILE_BLOCK; + csr_tile_to_dense_kernel + <<>>( + indptr.data(), indices.data(), data.data(), out.data(), + col_lb, col_ub, n_cells); + CUDA_CHECK_LAST_ERROR(csr_tile_to_dense_kernel); + }, + "indptr"_a, "indices"_a, "data"_a, "out"_a, nb::kw_only(), "col_lb"_a, + "col_ub"_a, "stream"_a = 0); +} + +// CSC -> dense F-order (double) window densify, fused pass (column-major). +template +static void def_csc_tile_to_dense(nb::module_& m) { + m.def( + "csc_tile_to_dense", + [](gpu_array_c indptr, + gpu_array_c indices, + gpu_array_c data, + gpu_array_f out, int col_lb, int col_ub, + std::uintptr_t stream) { + const int n_cells = static_cast(out.shape(0)); + const int n_win = col_ub - col_lb; + if (n_cells <= 0 || n_win <= 0) { + return; + } + if (col_lb < 0) { + throw std::invalid_argument( + "csc_tile_to_dense: col_lb must be non-negative"); + } + if (indices.shape(0) != data.shape(0)) { + throw std::invalid_argument( + "csc_tile_to_dense: indices and data must have equal " + "length"); + } + if (out.ndim() != 2 || + static_cast(out.shape(1)) < n_win) { + throw std::invalid_argument( + "csc_tile_to_dense: out must be a (n_cells, >= col_ub - " + "col_lb) array"); + } + constexpr int CSC_TILE_BLOCK = 128; + csc_tile_to_dense_kernel + <<(n_win), CSC_TILE_BLOCK, 0, + (cudaStream_t)stream>>>(indptr.data(), indices.data(), + data.data(), out.data(), col_lb, + col_ub, n_cells); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + }, + "indptr"_a, "indices"_a, "data"_a, "out"_a, nb::kw_only(), "col_lb"_a, + "col_ub"_a, "stream"_a = 0); +} + +template +void register_bindings(nb::module_& m) { + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + def_csr_tile_to_dense(m); + + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + def_csc_tile_to_dense(m); + + m.def( + "fdr_bh_reverse_cummin", + [](gpu_array_c values, std::uintptr_t stream) { + const int n_rows = static_cast(values.shape(0)); + const int n_cols = static_cast(values.shape(1)); + if (n_rows <= 0 || n_cols <= 0) { + return; + } + fdr_bh_reverse_cummin_kernel<<>>( + values.data(), n_cols); + CUDA_CHECK_LAST_ERROR(fdr_bh_reverse_cummin_kernel); + }, + "values"_a, nb::kw_only(), "stream"_a = 0); + + m.def( + "group_chunk_stats", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c group_sums, + gpu_array_c group_sum_sq, + gpu_array_c group_nnz, bool compute_nnz, + std::uintptr_t stream) { + if (block.ndim() != 2 || group_sums.ndim() != 2 || + group_sum_sq.ndim() != 2) { + throw std::invalid_argument( + "group_chunk_stats: block, group_sums and group_sum_sq " + "must be 2-D"); + } + const int n_rows = static_cast(block.shape(0)); + const int n_cols = static_cast(block.shape(1)); + const int n_groups = static_cast(group_sums.shape(0)); + const long long total = static_cast(n_rows) * n_cols; + if (total <= 0) { + return; + } + if (static_cast(group_codes.shape(0)) != n_rows) { + throw std::invalid_argument( + "group_chunk_stats: group_codes length must equal block " + "rows"); + } + if (static_cast(group_sum_sq.shape(0)) != n_groups || + static_cast(group_sums.shape(1)) != n_cols || + static_cast(group_sum_sq.shape(1)) != n_cols) { + throw std::invalid_argument( + "group_chunk_stats: group_sums and group_sum_sq must be " + "(n_groups, n_cols)"); + } + if (compute_nnz && + (group_nnz.ndim() != 2 || + static_cast(group_nnz.shape(0)) != n_groups || + static_cast(group_nnz.shape(1)) != n_cols)) { + throw std::invalid_argument( + "group_chunk_stats: group_nnz must be (n_groups, n_cols) " + "when compute_nnz is set"); + } + const unsigned int grid = strided_grid(total, GROUP_STATS_BLOCK); + group_chunk_stats_kernel<<>>( + block.data(), group_codes.data(), group_sums.data(), + group_sum_sq.data(), group_nnz.data(), n_rows, n_cols, n_groups, + compute_nnz); + CUDA_CHECK_LAST_ERROR(group_chunk_stats_kernel); + }, + "block"_a, "group_codes"_a, "group_sums"_a, "group_sum_sq"_a, + "group_nnz"_a, nb::kw_only(), "compute_nnz"_a, "stream"_a = 0); +} + +NB_MODULE(_rank_stats_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.cu b/src/rapids_singlecell/_cuda/rmm_scratch.cu new file mode 100644 index 00000000..474e6227 --- /dev/null +++ b/src/rapids_singlecell/_cuda/rmm_scratch.cu @@ -0,0 +1,33 @@ +#include +#include +#include + +#include +#include + +#include "rmm_scratch.h" + +// Use the RMM resource-ref API; RMM 26.06 removed the raw-pointer accessor. +// The ref form compiles unchanged from RMM 25.12 through 26.06+. +void* rmm_allocate(size_t bytes) { + try { + return rmm::mr::get_current_device_resource_ref().allocate_sync(bytes); + } catch (std::exception const& e) { + throw std::runtime_error( + std::string("RMM scratch allocation failed (") + + std::to_string(bytes) + " bytes): " + e.what()); + } +} + +void rmm_deallocate(void* ptr, size_t bytes) { + rmm::mr::get_current_device_resource_ref().deallocate_sync(ptr, bytes); +} + +// Plain cudaMemGetInfo budget query, never a pool-probing trial allocation. +// Probing ratchets RMM pools and can starve non-pool allocations like streams. +size_t rmm_available_device_bytes(double fraction) { + if (fraction <= 0.0) return 0; + size_t free_b = 0, total_b = 0; + if (cudaMemGetInfo(&free_b, &total_b) != cudaSuccess) return 0; + return (size_t)(free_b * fraction); +} diff --git a/src/rapids_singlecell/_cuda/rmm_scratch.h b/src/rapids_singlecell/_cuda/rmm_scratch.h new file mode 100644 index 00000000..4b41fe15 --- /dev/null +++ b/src/rapids_singlecell/_cuda/rmm_scratch.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Shared RMM-backed device scratch (link rmm::rmm via add_rmm_cuda_module). +// Allocates from the current RMM resource, sharing CuPy/RAPIDS's pool. +void* rmm_allocate(size_t bytes); +void rmm_deallocate(void* ptr, size_t bytes); + +// fraction * cudaMemGetInfo free; never trial-probe a pool. +// Probing ratchets RMM pools and can starve cudaStreamCreate. +size_t rmm_available_device_bytes(double fraction); + +// Allocation pool for temporary CUDA buffers; frees everything on scope exit. +struct RmmScratchPool { + struct Allocation { + void* ptr = nullptr; + size_t bytes = 0; + }; + std::vector bufs; + + ~RmmScratchPool() { + for (Allocation alloc : bufs) { + if (!alloc.ptr) continue; + rmm_deallocate(alloc.ptr, alloc.bytes); + } + } + + template + T* alloc(size_t count) { + if (count == 0) count = 1; + if (count > std::numeric_limits::max() / sizeof(T)) { + throw std::runtime_error("RMM scratch allocation size overflow"); + } + size_t bytes = count * sizeof(T); + void* ptr = rmm_allocate(bytes); + bufs.push_back({ptr, bytes}); + return static_cast(ptr); + } +}; + +// Single RAII RMM device buffer (frees on scope exit). +struct ScopedCudaBuffer { + void* ptr = nullptr; + size_t bytes = 0; + + explicit ScopedCudaBuffer(size_t requested_bytes) { + bytes = requested_bytes == 0 ? 1 : requested_bytes; + ptr = rmm_allocate(bytes); + } + + ~ScopedCudaBuffer() { + if (!ptr) return; + rmm_deallocate(ptr, bytes); + } + + void* data() { + return ptr; + } + + ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; + ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; +}; diff --git a/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh new file mode 100644 index 00000000..81a2c519 --- /dev/null +++ b/src/rapids_singlecell/_cuda/sparse_extract/sparse_extract.cuh @@ -0,0 +1,170 @@ +#pragma once + +#include + +// Shared CSR/CSC extraction kernels for compact CSC and dense F-order tiles. +// Callers canonicalize/sort before kernels that binary-search row indices. + +// Scatter CSR nonzeros into compact CSC for columns [col_start, col_stop). +// `row_offset` rebases local row blocks; write_pos is atomically claimed. +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop, int row_offset = 0) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row_offset + row; + } +} + +// CSR column window [col_lb, col_ub) -> pre-zeroed dense F-order tile. +// atomicAdd preserves summed duplicate semantics for canonicalized CSR. +template +__global__ void csr_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, + const IndexT* __restrict__ indices, + const TData* __restrict__ data, + OutT* __restrict__ out, int col_lb, + int col_ub, int n_cells) { + const long long row = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (row >= n_cells) { + return; + } + const long long row_start = static_cast(indptr[row]); + const long long row_end = static_cast(indptr[row + 1]); + // Keep column ids in IndexT: narrowing a 64-bit IndexT to int would + // truncate large column ids and misplace writes. + const IndexT lb = static_cast(col_lb); + const IndexT ub = static_cast(col_ub); + for (long long k = row_start; k < row_end; ++k) { + const IndexT col = indices[k]; + if (col >= lb && col < ub) { + atomicAdd(&out[static_cast(col - lb) * n_cells + row], + static_cast(data[k])); + } + } +} + +// CSC column window [col_lb, col_ub) -> pre-zeroed dense F-order tile. +// No atomics: canonical CSC has one stored value per (col, row). +template +__global__ void csc_tile_to_dense_kernel(const IndptrT* __restrict__ indptr, + const IndexT* __restrict__ indices, + const TData* __restrict__ data, + OutT* __restrict__ out, int col_lb, + int col_ub, int n_cells) { + const int col = col_lb + static_cast(blockIdx.x); + if (col >= col_ub) return; + const long long col_local = blockIdx.x; + const IndptrT s = indptr[col]; + const IndptrT e = indptr[col + 1]; + for (IndptrT p = s + threadIdx.x; p < e; p += blockDim.x) { + const long long row = static_cast(indices[p]); + out[col_local * n_cells + row] = static_cast(data[p]); + } +} + +// CSR selected rows -> pre-zeroed dense F-order tile. +// Requires sorted row indices for binary-search + col_stop break. +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (IndptrT p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} + +// CSR identity-mapped rows -> dense F-order; tolerates UNSORTED indices (full +// row scan, no binary search). One block per row. Output must be pre-zeroed. +template +__global__ void csr_extract_dense_identity_rows_unsorted_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_start && c < col_stop) { + out[(long long)(c - col_start) * n_target + row] = data[p]; + } + } +} + +// CSC selected rows -> pre-zeroed dense F-order tile. +// row_map[original_row] gives output row, or -1 to skip. +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + IndptrT start = indptr[col]; + IndptrT end = indptr[col + 1]; + + for (IndptrT p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +// Narrowing element-wise cast, used only when input index width exceeds int32. +// Caller guarantees row/column positions fit the destination type. +template +__global__ void cast_array_kernel(const SrcT* __restrict__ src, + DstT* __restrict__ dst, size_t n) { + size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) dst[i] = (DstT)src[i]; +} diff --git a/src/rapids_singlecell/_cuda/streaming/streaming.cuh b/src/rapids_singlecell/_cuda/streaming/streaming.cuh new file mode 100644 index 00000000..ebb9f4c9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/streaming/streaming.cuh @@ -0,0 +1,816 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../nb_types.h" +#include "../rmm_scratch.h" + +// Default thread-per-block for utility kernels shared by streaming pipelines. +constexpr int UTIL_BLOCK_SIZE = 256; +constexpr int DEFAULT_STREAMING_STREAMS = 4; +// Max per-batch nnz for segmented CUDA primitives that take int32 item counts. +constexpr size_t STREAMING_SAFE_BATCH_NNZ = 2000000000; // < INT_MAX +// Above this host span, avoid whole-array cudaHostRegister and use bounded +// staging. Moderate arrays keep the lower-overhead direct async-copy path. +constexpr size_t HOST_STREAMING_DIRECT_PIN_LIMIT_BYTES = + 16ULL * 1024ULL * 1024ULL * 1024ULL; + +// Host thread count for CPU-side staging passes: hardware concurrency, capped. +static inline int host_worker_count() { + unsigned hw = std::thread::hardware_concurrency(); + return (int)std::min(hw ? hw : 4u, 32u); +} + +// Run fn(chunk, r0, r1) over partitions of [0, n), serial for small n. +// Concurrent callers must use read-only shared state and disjoint outputs. +template +static inline int host_parallel_chunks(int n, F fn) { + if (n <= 0) return 0; + int n_threads = host_worker_count(); + if (n_threads <= 1 || n < 4096) { + fn(0, 0, n); + return 1; + } + int chunk = (n + n_threads - 1) / n_threads; + std::vector pool; + pool.reserve(n_threads); + for (int t = 0; t < n_threads; t++) { + int r0 = t * chunk; + if (r0 >= n) break; + int r1 = std::min(n, r0 + chunk); + pool.emplace_back([&fn, t, r0, r1]() { fn(t, r0, r1); }); + } + int used = (int)pool.size(); + for (std::thread& th : pool) th.join(); + return used; +} + +// Run fn(r0, r1) over a partition of [0, n) across hardware threads (serial for +// small n). Concurrent: read-only shared state, disjoint output ranges. +template +static inline void host_parallel_ranges(int n, F fn) { + host_parallel_chunks(n, [&fn](int, int r0, int r1) { fn(r0, r1); }); +} + +static inline int checked_cub_items(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds CUB int item limit"); + } + return (int)count; +} + +static inline int checked_int_span(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds int32 offset limit"); + } + return (int)count; +} + +static inline int checked_int_product(size_t a, size_t b, const char* context) { + if (a != 0 && b > (size_t)std::numeric_limits::max() / a) { + throw std::runtime_error(std::string(context) + + " exceeds int32 item limit"); + } + return (int)(a * b); +} + +template +struct SparseWindowDTypes { + using value_type = DeviceValueT; + using index_type = DeviceIndexT; + using accum_type = AccumT; + + static constexpr size_t bytes_per_nnz = + sizeof(value_type) + sizeof(index_type); +}; + +using WilcoxonSparseWindowDTypes = SparseWindowDTypes; + +template +static inline size_t sparse_window_nnz_bytes(size_t nnz) { + return nnz * DTypes::bytes_per_nnz; +} + +template +static inline size_t sparse_window_accum_bytes(size_t count) { + return count * sizeof(typename DTypes::accum_type); +} + +static inline void host_clear_id_map(int* id_map, int n_items) { + std::fill(id_map, id_map + n_items, -1); +} + +static inline void host_build_id_map(const int* ids, int n_ids, int* id_map, + int n_items, const char* what) { + host_clear_id_map(id_map, n_items); + for (int local = 0; local < n_ids; local++) { + int id = ids[local]; + if (id < 0 || id >= n_items) { + throw std::runtime_error(std::string(what) + + " id is out of bounds"); + } + id_map[id] = local; + } +} + +static inline void host_build_contiguous_id_map(int first, int count, + int* id_map, int n_items, + const char* what) { + if (first < 0 || count < 0 || first > n_items - count) { + throw std::runtime_error(std::string(what) + + " contiguous id window is out of bounds"); + } + host_clear_id_map(id_map, n_items); + for (int local = 0; local < count; local++) id_map[first + local] = local; +} + +// Stream-count clamps: never use more streams than column batches, nor more +// than the per-stream memory budget allows. +static inline int clamp_streams_by_cols( + int n_cols, int sub_batch_cols, + int max_streams = DEFAULT_STREAMING_STREAMS) { + int n = max_streams; + if (n_cols < n * sub_batch_cols) + n = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + return n; +} + +static inline int clamp_streams_by_budget(int n_streams, + size_t per_stream_bytes, + size_t budget) { + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + return n_streams; +} + +// Scatter row-major [rows, sb_cols] into destination stride n_cols. +// `dst` must already point at the destination column offset. +static inline void scatter_cols_2d(double* dst, const double* src, int rows, + int n_cols, int sb_cols, + cudaStream_t stream) { + cudaMemcpy2DAsync(dst, n_cols * sizeof(double), src, + sb_cols * sizeof(double), sb_cols * sizeof(double), rows, + cudaMemcpyDeviceToDevice, stream); +} + +// Halve sub_batch_cols until the densest window holds <= cap nonzeros. +// Keeps CUB item counts and per-stream scratch bounded; worst case returns 1. +template +static inline int cap_sub_batch_by_nnz(int n_cols, int sub_batch_cols, + size_t cap, ColNnz col_nnz) { + if (cap < 1) cap = 1; + auto max_window = [&](int s) { + size_t mx = 0; + for (int c = 0; c < n_cols; c += s) { + int e = std::min(c + s, n_cols); + size_t sum = 0; + for (int i = c; i < e; i++) sum += col_nnz(i); + if (sum > mx) mx = sum; + } + return mx; + }; + while (sub_batch_cols > 1 && max_window(sub_batch_cols) > cap) + sub_batch_cols = (sub_batch_cols + 1) / 2; + return sub_batch_cols; +} + +struct ColumnBatchPlan { + int sub_batch_cols = 0; + int n_batches = 0; + size_t max_nnz = 0; + std::vector offsets; + std::vector nnz; +}; + +struct HostCompactSparseWindowPlan { + int major_count = 0; + size_t nnz = 0; + std::vector indptr; +}; + +struct DenseColumnBatchPlan { + int sub_batch_cols = 0; + int n_batches = 0; + size_t max_items = 0; +}; + +static inline DenseColumnBatchPlan plan_dense_column_batches( + int n_rows, int n_cols, int sub_batch_cols, size_t cap, const char* what) { + DenseColumnBatchPlan plan; + if (sub_batch_cols < 1) sub_batch_cols = 1; + if (cap < 1) cap = 1; + checked_cub_items((size_t)n_rows, what); + + size_t max_cols = + n_rows > 0 ? cap / (size_t)n_rows : (size_t)sub_batch_cols; + if (max_cols < 1) max_cols = 1; + if ((size_t)sub_batch_cols > max_cols) sub_batch_cols = (int)max_cols; + + plan.sub_batch_cols = sub_batch_cols; + plan.n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + plan.max_items = (size_t)n_rows * (size_t)sub_batch_cols; + checked_cub_items(plan.max_items, what); + return plan; +} + +template +static inline ColumnBatchPlan plan_column_batches_from_counts( + int n_cols, int sub_batch_cols, size_t cap, CountAt count_at, + const char* what) { + ColumnBatchPlan plan; + plan.sub_batch_cols = + cap_sub_batch_by_nnz(n_cols, sub_batch_cols, cap, count_at); + plan.n_batches = (n_cols + plan.sub_batch_cols - 1) / plan.sub_batch_cols; + plan.offsets.assign((size_t)plan.n_batches * (plan.sub_batch_cols + 1), 0); + plan.nnz.assign(plan.n_batches, 0); + for (int b = 0; b < plan.n_batches; b++) { + int col_start = b * plan.sub_batch_cols; + int sb = std::min(plan.sub_batch_cols, n_cols - col_start); + int* off = &plan.offsets[(size_t)b * (plan.sub_batch_cols + 1)]; + for (int i = 0; i < sb; i++) { + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)count_at(col_start + i), what); + } + plan.nnz[b] = (size_t)off[sb]; + if (plan.nnz[b] > plan.max_nnz) plan.max_nnz = plan.nnz[b]; + } + return plan; +} + +template +static inline ColumnBatchPlan plan_csc_column_batches(const IndptrT* h_indptr, + int n_cols, + int sub_batch_cols, + size_t cap, + const char* what) { + return plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)(h_indptr[c + 1] - h_indptr[c]); }, what); +} + +static inline int* upload_batch_offsets(const ColumnBatchPlan& plan, + RmmScratchPool& pool) { + int* d_all_offsets = pool.alloc(plan.offsets.size()); + cudaMemcpy(d_all_offsets, plan.offsets.data(), + plan.offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + return d_all_offsets; +} + +template +static HostCompactSparseWindowPlan plan_compact_sparse_window( + int major_count, CountAt count_at, const char* what) { + HostCompactSparseWindowPlan plan; + plan.major_count = major_count; + plan.indptr.assign((size_t)major_count + 1, 0); + if (major_count <= 0) return plan; + + std::vector counts(major_count, 0); + host_parallel_ranges(major_count, [&](int i0, int i1) { + for (int i = i0; i < i1; i++) counts[i] = count_at(i); + }); + + size_t run = 0; + for (int i = 0; i < major_count; i++) { + plan.indptr[i] = checked_int_span(run, what); + run += counts[i]; + } + plan.indptr[major_count] = checked_int_span(run, what); + plan.nnz = run; + return plan; +} + +template +static HostCompactSparseWindowPlan plan_csc_rows_window( + const IndexT* h_indices, const IndptrT* h_indptr, int col_start, + int n_window_cols, RowToLocal row_to_local, const char* what) { + return plan_compact_sparse_window( + n_window_cols, + [&](int local_col) { + int col = col_start + local_col; + size_t count = 0; + for (IndptrT p = h_indptr[col]; p < h_indptr[col + 1]; p++) { + if (row_to_local((int)h_indices[p]) >= 0) count++; + } + return count; + }, + what); +} + +template +static HostCompactSparseWindowPlan plan_csr_cols_window( + const IndexT* h_indices, const IndptrT* h_indptr, const int* row_ids, + int n_window_rows, ColToLocal col_to_local, const char* what) { + return plan_compact_sparse_window( + n_window_rows, + [&](int local_row) { + int row = row_ids ? row_ids[local_row] : local_row; + size_t count = 0; + for (IndptrT p = h_indptr[row]; p < h_indptr[row + 1]; p++) { + if (col_to_local((int)h_indices[p]) >= 0) count++; + } + return count; + }, + what); +} + +template +static HostCompactSparseWindowPlan plan_csc_rows_window_from_map( + const IndexT* h_indices, const IndptrT* h_indptr, int col_start, + int n_window_cols, const int* row_map, const char* what) { + return plan_csc_rows_window( + h_indices, h_indptr, col_start, n_window_cols, + [&](int row) { return row_map[row]; }, what); +} + +template +static HostCompactSparseWindowPlan plan_csr_cols_window_from_map( + const IndexT* h_indices, const IndptrT* h_indptr, const int* row_ids, + int n_window_rows, const int* col_map, const char* what) { + return plan_csr_cols_window( + h_indices, h_indptr, row_ids, n_window_rows, + [&](int col) { return col_map[col]; }, what); +} + +// RAII guard for cudaHostRegister: unregisters on scope exit (incl. exception +// unwind), preventing leaked host pinning on stream-sync failures. +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0, + bool best_effort = false) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory is owned elsewhere; use as-is. + // Other failures are fatal unless pinning is only a speedup. + if (err == cudaErrorHostMemoryAlreadyRegistered || + best_effort) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// RAII for CUDA streams/events: stream destruction synchronizes first. +// Declare RmmScratchPool before guards so streams drain before scratch frees. +struct ScopedCudaStream { + cudaStream_t stream = nullptr; + + ScopedCudaStream() = default; + explicit ScopedCudaStream(unsigned int flags) { + cuda_check(cudaStreamCreateWithFlags(&stream, flags), + "cudaStreamCreateWithFlags"); + } + ~ScopedCudaStream() { + if (stream) { + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + } + } + operator cudaStream_t() const { + return stream; + } + cudaStream_t get() const { + return stream; + } + ScopedCudaStream(const ScopedCudaStream&) = delete; + ScopedCudaStream& operator=(const ScopedCudaStream&) = delete; +}; + +struct ScopedCudaStreams { + std::vector streams; + + // `flags` is explicit so call sites keep their original stream semantics. + ScopedCudaStreams(int n, unsigned int flags) { + streams.reserve(n > 0 ? (size_t)n : 0); + for (int i = 0; i < n; ++i) { + cudaStream_t s = nullptr; + cudaError_t err = cudaStreamCreateWithFlags(&s, flags); + if (err != cudaSuccess) { + // dtor won't run on ctor throw; reclaim what we made. + for (cudaStream_t prev : streams) { + cudaStreamSynchronize(prev); + cudaStreamDestroy(prev); + } + throw std::runtime_error( + std::string("cudaStreamCreateWithFlags failed: ") + + cudaGetErrorString(err)); + } + streams.push_back(s); + } + } + ~ScopedCudaStreams() { + for (cudaStream_t s : streams) { + if (!s) continue; + cudaStreamSynchronize(s); + cudaStreamDestroy(s); + } + } + cudaStream_t operator[](int i) const { + return streams[i]; + } + int size() const { + return (int)streams.size(); + } + ScopedCudaStreams(const ScopedCudaStreams&) = delete; + ScopedCudaStreams& operator=(const ScopedCudaStreams&) = delete; +}; + +// Drain every stream, surfacing the first async error with a context label. +static inline void sync_streams(const ScopedCudaStreams& streams, + const char* what) { + for (int i = 0; i < streams.size(); ++i) { + cudaError_t err = cudaStreamSynchronize(streams[i]); + if (err != cudaSuccess) + throw std::runtime_error(std::string("CUDA error in ") + what + + ": " + cudaGetErrorString(err)); + } +} + +struct ScopedCudaEvent { + cudaEvent_t event = nullptr; + + ScopedCudaEvent() = default; + explicit ScopedCudaEvent(unsigned int flags) { + cuda_check(cudaEventCreateWithFlags(&event, flags), + "cudaEventCreateWithFlags"); + } + ~ScopedCudaEvent() { + if (event) cudaEventDestroy(event); + } + void record(cudaStream_t stream) { + cuda_check(cudaEventRecord(event, stream), "cudaEventRecord"); + } + cudaEvent_t get() const { + return event; + } + ScopedCudaEvent(const ScopedCudaEvent&) = delete; + ScopedCudaEvent& operator=(const ScopedCudaEvent&) = delete; +}; + +template +struct PinnedRingArray { + std::vector> data; + std::vector pins; + + PinnedRingArray() = default; + PinnedRingArray(int n_slots, size_t count) : data(n_slots), pins(n_slots) { + size_t n = count ? count : 1; + for (int s = 0; s < n_slots; s++) { + data[s].reset(new T[n]); + pins[s] = HostRegisterGuard(data[s].get(), n * sizeof(T)); + } + } + T* get(int slot) { + return data[slot].get(); + } + const T* get(int slot) const { + return data[slot].get(); + } +}; + +// Per-slot pinned host staging with events for CPU/GPU overlap. +// Arrays share item capacity; use another ring for differently-sized metadata. +template +struct PinnedRing { + std::tuple...> arrays; + std::vector evt; + std::vector used; + int n_slots = 0; + size_t capacity = 0; + + PinnedRing(int n_slots_, size_t count) + : arrays(PinnedRingArray(n_slots_, count)...), + evt(n_slots_, nullptr), + used(n_slots_, 0) { + n_slots = n_slots_; + capacity = count ? count : 1; + for (int s = 0; s < n_slots; s++) { + cuda_check( + cudaEventCreateWithFlags(&evt[s], cudaEventDisableTiming), + "PinnedRing event create"); + } + } + ~PinnedRing() { + for (size_t s = 0; s < evt.size(); ++s) { + cudaEvent_t e = evt[s]; + if (!e) continue; + if (s < used.size() && used[s]) cudaEventSynchronize(e); + cudaEventDestroy(e); + } + } + void wait(int s) { + if (used[s]) + cuda_check(cudaEventSynchronize(evt[s]), "PinnedRing reuse"); + } + void record(int s, cudaStream_t stream) { + cuda_check(cudaEventRecord(evt[s], stream), "PinnedRing record"); + used[s] = true; + } + template + typename std::tuple_element>::type* get(int slot) { + return std::get(arrays).get(slot); + } + template + const typename std::tuple_element>::type* get( + int slot) const { + return std::get(arrays).get(slot); + } + PinnedRing(const PinnedRing&) = delete; + PinnedRing& operator=(const PinnedRing&) = delete; +}; + +template +using SparseWindowStagingRing = + PinnedRing; + +using HostStagingRing = SparseWindowStagingRing; + +/** Fill linear segment offsets [0, stride, ..., n_segments*stride] on-device. + */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Rebase indptr slice to a local origin, grid-strided for arbitrary count. + * 64-bit global indptrs may produce 32-bit pack-local indptrs. */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +// Threaded selected-row gather into compact staging at disjoint offsets. +// No-pin alternative: only the compacted slice crosses the bus. +template +static void host_gather_rows_compact_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, const CompactT* compact_indptr, CompactT base, + int n_target, StageValT* stage_vals, StageIndexT* stage_cols) { + host_parallel_ranges(n_target, [&](int i0, int i1) { + for (int i = i0; i < i1; i++) { + int r = row_ids[i]; + IndptrT rs = h_indptr[r]; + int nnz = (int)(h_indptr[r + 1] - rs); + size_t ds = (size_t)(compact_indptr[i] - base); + for (int k = 0; k < nnz; k++) { + stage_vals[ds + k] = (StageValT)h_data[rs + k]; + stage_cols[ds + k] = (StageIndexT)h_indices[rs + k]; + } + } + }); +} + +template +static void host_gather_rows_compact(const InT* h_data, const IndexT* h_indices, + const IndptrT* h_indptr, + const int* row_ids, + const CompactT* compact_indptr, + CompactT base, int n_target, + float* stage_vals, int* stage_cols) { + host_gather_rows_compact_as(h_data, h_indices, h_indptr, + row_ids, compact_indptr, base, + n_target, stage_vals, stage_cols); +} + +// Threaded host cast-copy of a contiguous nnz slice into staging. +// CSC analogue of row gather: contiguous column batch, bounded int32 nnz. +template +static void host_copy_slice_as(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, StageValT* stage_vals, + StageIndexT* stage_cols) { + host_parallel_ranges(nnz, [&](int k0, int k1) { + for (int k = k0; k < k1; k++) { + stage_vals[k] = (StageValT)h_data[start + k]; + stage_cols[k] = (StageIndexT)h_indices[start + k]; + } + }); +} + +template +static void host_copy_slice(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, InT* stage_vals, + IndexT* stage_cols) { + host_copy_slice_as(h_data, h_indices, start, nnz, stage_vals, + stage_cols); +} + +template +static void host_cast_copy_slice(const InT* h_data, const IndexT* h_indices, + size_t start, int nnz, float* stage_vals, + int* stage_cols) { + host_copy_slice_as(h_data, h_indices, start, nnz, stage_vals, + stage_cols); +} + +// Threaded host gather of selected dense rows and contiguous columns. +// Output staging is always F-order [n_window_rows, n_window_cols]. +template +static void host_materialize_dense_rows_window_as( + const InT* h_X, bool f_order, int n_full_rows, int n_full_cols, + const int* row_ids, int n_window_rows, int col_start, int n_window_cols, + StageT* stage) { + int total = + checked_int_product((size_t)n_window_rows, (size_t)n_window_cols, + "dense host row-window items"); + host_parallel_ranges(total, [&](int i0, int i1) { + for (int idx = i0; idx < i1; idx++) { + int local_col = idx / n_window_rows; + int local_row = idx - local_col * n_window_rows; + int row = row_ids ? row_ids[local_row] : local_row; + int col = col_start + local_col; + size_t src = f_order ? (size_t)col * n_full_rows + row + : (size_t)row * n_full_cols + col; + stage[(size_t)local_col * n_window_rows + local_row] = + (StageT)h_X[src]; + } + }); +} + +template +static void host_materialize_dense_rows_window(const InT* h_X, bool f_order, + int n_full_rows, int n_full_cols, + const int* row_ids, + int n_window_rows, int col_start, + int n_window_cols, InT* stage) { + host_materialize_dense_rows_window_as( + h_X, f_order, n_full_rows, n_full_cols, row_ids, n_window_rows, + col_start, n_window_cols, stage); +} + +// Cross-axis CSC materialization: filter a contiguous column window by selected +// rows and emit compact CSC with local row ids. +template +static void host_materialize_csc_rows_window_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int col_start, int n_window_cols, const int* compact_indptr, + RowToLocal row_to_local, StageValT* stage_vals, StageIndexT* stage_rows) { + host_parallel_ranges(n_window_cols, [&](int c0, int c1) { + for (int local_col = c0; local_col < c1; local_col++) { + int col = col_start + local_col; + size_t dst = (size_t)compact_indptr[local_col]; + for (IndptrT p = h_indptr[col]; p < h_indptr[col + 1]; p++) { + int local_row = row_to_local((int)h_indices[p]); + if (local_row < 0) continue; + stage_vals[dst] = (StageValT)h_data[p]; + stage_rows[dst] = (StageIndexT)local_row; + dst++; + } + } + }); +} + +template +static void host_materialize_csc_rows_window( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int col_start, int n_window_cols, const int* compact_indptr, + const int* row_map, float* stage_vals, int* stage_rows) { + host_materialize_csc_rows_window_as( + h_data, h_indices, h_indptr, col_start, n_window_cols, compact_indptr, + [&](int row) { return row_map[row]; }, stage_vals, stage_rows); +} + +// Cross-axis CSR materialization: filter selected rows by selected columns and +// emit compact CSR with local column ids. +template +static void host_materialize_csr_cols_window_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, int n_window_rows, const int* compact_indptr, + ColToLocal col_to_local, StageValT* stage_vals, StageIndexT* stage_cols) { + host_parallel_ranges(n_window_rows, [&](int r0, int r1) { + for (int local_row = r0; local_row < r1; local_row++) { + int row = row_ids ? row_ids[local_row] : local_row; + size_t dst = (size_t)compact_indptr[local_row]; + for (IndptrT p = h_indptr[row]; p < h_indptr[row + 1]; p++) { + int local_col = col_to_local((int)h_indices[p]); + if (local_col < 0) continue; + stage_vals[dst] = (StageValT)h_data[p]; + stage_cols[dst] = (StageIndexT)local_col; + dst++; + } + } + }); +} + +template +static void host_materialize_csr_cols_window( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* row_ids, int n_window_rows, const int* compact_indptr, + const int* col_map, float* stage_vals, int* stage_cols) { + host_materialize_csr_cols_window_as( + h_data, h_indices, h_indptr, row_ids, n_window_rows, compact_indptr, + [&](int col) { return col_map[col]; }, stage_vals, stage_cols); +} + +// Optimized CSR -> contiguous-column-window materialization for sorted rows. +// The per-row cursor examines each nonzero once across the full stream. +template +static int host_materialize_csr_column_interval_cursor_as( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_rows, int col_start, int col_end, IndptrT* cursor, int* row_counts, + int* compact_indptr, StageValT* stage_vals, StageIndexT* stage_cols, + const char* what) { + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + const IndexT* row_base = h_indices + h_indptr[r]; + const IndexT* lo = row_base + cursor[r]; + const IndexT* hi = h_indices + h_indptr[r + 1]; + if (lo < hi && *lo < (IndexT)col_start) { + lo = std::lower_bound(lo, hi, (IndexT)col_start); + cursor[r] = (IndptrT)(lo - row_base); + } + row_counts[r] = + (int)(std::lower_bound(lo, hi, (IndexT)col_end) - lo); + } + }); + + compact_indptr[0] = 0; + for (int r = 0; r < n_rows; r++) { + compact_indptr[r + 1] = checked_int_span( + (size_t)compact_indptr[r] + (size_t)row_counts[r], what); + } + int batch_nnz = compact_indptr[n_rows]; + + host_parallel_ranges(n_rows, [&](int r0, int r1) { + for (int r = r0; r < r1; r++) { + IndptrT base = h_indptr[r] + cursor[r]; + size_t dst = (size_t)compact_indptr[r]; + int count = row_counts[r]; + for (int k = 0; k < count; k++) { + stage_vals[dst + k] = (StageValT)h_data[base + k]; + stage_cols[dst + k] = (StageIndexT)h_indices[base + k]; + } + cursor[r] += count; + } + }); + return batch_nnz; +} + +template +static int host_materialize_csr_column_interval_cursor( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_rows, int col_start, int col_end, IndptrT* cursor, int* row_counts, + int* compact_indptr, InT* stage_vals, int* stage_cols, const char* what) { + return host_materialize_csr_column_interval_cursor_as( + h_data, h_indices, h_indptr, n_rows, col_start, col_end, cursor, + row_counts, compact_indptr, stage_vals, stage_cols, what); +} + +/** Fill linear segment offsets [0, stride, ...] on the supplied stream (avoids + * serializing multi-stream pipelines). */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..2e963df8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,143 +2,60 @@ #include -/** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. - * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). - */ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column +#include "wilcoxon_block_reduce.cuh" +#include "wilcoxon_ovr_tie_walk.cuh" + +// Dense OVR rank kernel over sorted F-order columns; no rank matrix +// materialized. CRITICAL: `use_gmem` is required when n_groups exceeds +// shared-memory capacity. +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; if (col >= n_cols) return; - const double* sv = sorted_vals + (size_t)col * n_rows; + extern __shared__ double smem[]; - double local_sum = 0.0; - int tid = threadIdx.x; - - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != sv[i - 1]); - - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_count = i - lo; - - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; + double* grp_sums; + if (use_gmem) { + grp_sums = rank_sums + (size_t)col; + } else { + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; } + __syncthreads(); } - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; + + int acc_stride = use_gmem ? n_cols : 1; + double local_tie_sum = ovr_walk_tie_runs( + sv, si, group_codes, grp_sums, acc_stride, n_groups, my_start, my_end, + /*seg_floor=*/0, /*seg_ceil=*/n_rows, /*rank_offset=*/0.0, + compute_tie_corr); - if (lane == 0) { - warp_sums[warp_id] = local_sum; - } __syncthreads(); - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (tid == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; - } + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; } } -} - -/** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. - * - * Each block handles one column. Assumes input is sorted column-wise (F-order). - */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; - - // Each thread processes multiple rows - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; - - // Binary search for tie_start (first element equal to val) - int lo = 0, hi = i; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_start = lo; - - // Binary search for tie_end (last element equal to val) - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { - hi = mid - 1; - } else { - lo = mid; - } - } - int tie_end = lo; - - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - // Write rank to original position - rk[si[i]] = avg_rank; + if (compute_tie_corr) { + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; + double tie_sum = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) + tie_corr[col] = finalize_tie_corr(n_rows, tie_sum); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh new file mode 100644 index 00000000..98ea3f06 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -0,0 +1,323 @@ +#pragma once + +#include + +#include "wilcoxon_block_reduce.cuh" +#include "wilcoxon_fast_common.cuh" + +// Bitonic sort of power-of-two `n` floats in shared memory, ascending. +// Pad the tail with +INF before calling; any blockDim works. + +__device__ __forceinline__ void bitonic_sort_smem(float* s, int n) { + for (int k = 2; k <= n; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < n; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = s[i], b = s[ixj]; + if (asc ? (a > b) : (a < b)) { + s[i] = b; + s[ixj] = a; + } + } + } + __syncthreads(); + } + } +} + +// Sorted-array bounds over [lo, hi): lower is first >= v, upper first > v. +// Advanced `lo` exploits monotonic strides; global/shared arrays both work. + +__device__ __forceinline__ int sorted_lower_bound(const float* arr, int lo, + int hi, float v) { + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (arr[m] < v) + lo = m + 1; + else + hi = m; + } + return lo; +} + +__device__ __forceinline__ int sorted_upper_bound(const float* arr, int lo, + int hi, float v) { + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (arr[m] <= v) + lo = m + 1; + else + hi = m; + } + return lo; +} + +// Mid-rank of `v` in merged (ref, grp) arrays with incremental bounds. +// Also reports equal counts per array for tie correction. +struct OvoRank { + double mid_rank; + int n_eq_ref; + int n_eq_grp; +}; + +__device__ __forceinline__ OvoRank ovo_mid_rank(const float* ref, int n_ref, + const float* grp, int n_grp, + float v, int& ref_lb, + int& ref_ub, int& grp_lb, + int& grp_ub) { + int n_lt_ref = sorted_lower_bound(ref, ref_lb, n_ref, v); + ref_lb = n_lt_ref; + ref_ub = sorted_upper_bound(ref, ref_ub > n_lt_ref ? ref_ub : n_lt_ref, + n_ref, v); + int n_eq_ref = ref_ub - n_lt_ref; + + int n_lt_grp = sorted_lower_bound(grp, grp_lb, n_grp, v); + grp_lb = n_lt_grp; + grp_ub = sorted_upper_bound(grp, grp_ub > n_lt_grp ? grp_ub : n_lt_grp, + n_grp, v); + int n_eq_grp = grp_ub - n_lt_grp; + + OvoRank r; + r.mid_rank = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + r.n_eq_ref = n_eq_ref; + r.n_eq_grp = n_eq_grp; + return r; +} + +// Amortized tie correction for sorted LARGE/HUGE groups. +// Only unique group values update the precomputed ref tie base. +__device__ __forceinline__ void compute_tie_delta_sorted_grp( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double ref_base, double* warp_buf, double* out) { + double local = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + // run-start of a unique value in the sorted group + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + int gub = sorted_upper_bound(grp_col, i + 1, n_grp, v); + double cg = (double)(gub - i); + int rlo = sorted_lower_bound(ref_col, 0, n_ref, v); + int rub = sorted_upper_bound(ref_col, rlo, n_ref, v); + double cr = (double)(rub - rlo); + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local += combined * combined * combined - combined - ref_tie - + group_tie; + } + } + } + double tie = wilcoxon_block_sum(local, warp_buf); + if (threadIdx.x == 0) + *out = finalize_tie_corr(n_ref + n_grp, ref_base + tie); +} + +// No-tie fast path: group-internal ranks collapse to the U closed form. +// Each unsorted group value binary-searches the sorted reference; no group +// sort. +__global__ void ovo_rank_dense_vs_ref_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + int n_ref, int n_all_grp, int n_cols, int n_groups) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int n_grp = grp_offsets[grp + 1] - g_start; + if (n_grp == 0) { + if (threadIdx.x == 0) rank_sums[(size_t)grp * n_cols + col] = 0.0; + return; + } + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + + double local_sum = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int n_lt = sorted_lower_bound(ref_col, 0, n_ref, v); + int n_eq = sorted_upper_bound(ref_col, n_lt, n_ref, v) - n_lt; + local_sum += (double)n_lt + 0.5 * (double)n_eq; + } + __shared__ double warp_buf[32]; + double total = wilcoxon_block_sum(local_sum, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = + total + (double)n_grp * ((double)n_grp + 1.0) / 2.0; + } +} + +// LARGE/HUGE rank kernel; LARGE smem-sorts, HUGE reads CUB-sorted groups. +// Post-sort mid-rank/tie body is shared and each group owns its output row. +template +__global__ void ovo_rank_sorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_in, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int large_padded, int skip_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int n_grp = grp_offsets[grp + 1] - g_start; + if (n_grp <= skip_n_grp_le) return; + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + __shared__ double warp_buf[32]; + const float* grp_col; + if constexpr (SMEM_SORT) { + extern __shared__ float grp_smem[]; + const float* src = grp_in + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = src[i]; + for (int i = n_grp + threadIdx.x; i < large_padded; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF pad + __syncthreads(); + bitonic_sort_smem(grp_smem, large_padded); + grp_col = grp_smem; + } else { + (void)large_padded; + grp_col = + grp_in + (long long)col * n_all_grp + g_start; // CUB-presorted + } + + int ref_lb = 0, ref_ub = 0, grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + OvoRank r = ovo_mid_rank(ref_col, n_ref, grp_col, n_grp, grp_col[i], + ref_lb, ref_ub, grp_lb, grp_ub); + local_sum += r.mid_rank; + } + double total = wilcoxon_block_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + // grp_col is sorted: amortize the ref tie contribution via the precomputed + // base instead of rescanning the ref per group. + compute_tie_delta_sorted_grp(ref_col, n_ref, grp_col, n_grp, + ref_tie_sums[col], warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// MEDIUM tie helper: sorted-reference contribution, one block per column. +// Rank kernels add only group-only/ref-overlap deltas. +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int cnt = sorted_upper_bound(ref_col, i + 1, n_ref, v) - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +// MEDIUM fused kernel: ref binary searches plus in-group scan over smem values. +// Tie correction starts from ref_tie_sums[col] and adds only group deltas. +__global__ void ovo_rank_medium_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int n_lt_ref = sorted_lower_bound(ref_col, 0, n_ref, v); + int n_eq_ref = + sorted_upper_bound(ref_col, n_lt_ref, n_ref, v) - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = wilcoxon_block_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = wilcoxon_block_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) + tie_corr[grp * n_cols + col] = + finalize_tie_corr(n_ref + n_grp, ref_tie_sums[col] + tie_delta); +} + +// WARP/SMALL tiers were removed; MEDIUM now covers all groups <= +// OVO_MEDIUM_MAX. Restore notes live in +// .claude/wilcoxon-warp-small-tiers-removed.md. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d25f7d0f..01cbee41 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,70 +1,1047 @@ #include +#include + +#include +#include +#include +#include + #include "../nb_types.h" #include "kernels_wilcoxon.cuh" +#include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovo_kernels.cuh" using namespace nb::literals; -// Constants for kernel launch configuration -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; +static void launch_ovr_rank_dense_streaming( + const float* block, const int* group_codes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + DenseColumnBatchPlan batches = plan_dense_column_batches( + n_rows, n_cols, sub_batch_cols, SAFE_BATCH_NNZ, "Dense OVR sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t sub_items = batches.max_items; + int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); + + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); + + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cuda_check(cudaStreamWaitEvent(streams[i], inputs_ready.get(), 0), + "wait on inputs_ready (dense OVR)"); + } + + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense OVR active sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; + const float* keys_in = block + (size_t)col * n_rows; + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "dense OVR segmented sort"); + + if (use_gmem) { + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "dense OVR gmem rank_sums memset"); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense OVR rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check(cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream), + "dense OVR tie_corr D2D copy"); + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense OVR streaming rank"); } -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); +// Host-streaming dense OVR: pinned multi-stream batches into F-order device +// slabs. F-order copies contiguous; C-order uses 2D copy; stats accumulate in +// f64. +template +static void launch_ovr_rank_dense_host_streaming( + const T* h_X, bool f_order, const int* group_codes, double* rank_sums, + double* tie_corr, double* group_sums, double* group_nnz, double* total_sums, + double* total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_totals, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + const bool compute_stats = group_sums != nullptr; + compute_nnz = compute_nnz && (group_nnz != nullptr); + compute_totals = compute_stats && compute_totals && (total_sums != nullptr); + // F-order float32 input feeds the sort directly (no cast/transpose buffer). + const bool fast_keys = f_order && std::is_same::value; + + DenseColumnBatchPlan batches = + plan_dense_column_batches(n_rows, n_cols, sub_batch_cols, + SAFE_BATCH_NNZ, "Dense host OVR sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t sub_items = batches.max_items; + int sub_items_i32 = + checked_cub_items(sub_items, "Dense host OVR sub-batch"); + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, sub_batch_cols); + + // Clamp stream count to device memory budget so a large matrix shrinks the + // pipeline rather than OOMing on per-stream sort scratch. + size_t per_stream_bytes = + sub_items * (sizeof(T) + (fast_keys ? 0 : sizeof(float)) + + sizeof(float) + 2 * sizeof(int)) + + cub_temp_bytes + (size_t)(sub_batch_cols + 1) * sizeof(int) + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + (size_t)sub_batch_cols * sizeof(double) + + (compute_stats ? (size_t)n_groups * sub_batch_cols * sizeof(double) + : 0) + + (compute_nnz ? (size_t)n_groups * sub_batch_cols * sizeof(double) : 0) + + (compute_totals ? (size_t)sub_batch_cols * sizeof(double) : 0) + + (compute_totals && compute_nnz ? (size_t)sub_batch_cols * sizeof(double) + : 0); + n_streams = clamp_streams_by_budget(n_streams, per_stream_bytes, + rmm_available_device_bytes(0.8)); + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + // Best-effort pin for faster async H2D; on failure proceed unpinned. + HostRegisterGuard _pin(const_cast(h_X), + (size_t)n_rows * n_cols * sizeof(T), 0, + /*best_effort=*/true); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + struct StreamBuf { + T* d_stg; + float* block_f32; + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_nnz; + double* sub_total_sums; + double* sub_total_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].d_stg = pool.alloc(sub_items); + bufs[s].block_f32 = fast_keys ? nullptr : pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + compute_stats + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense host OVR active sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + // H2D the column window (overlaps the prior batch rank). + if (f_order) { + cudaMemcpyAsync(buf.d_stg, h_X + (size_t)col * n_rows, + (size_t)sb_items * sizeof(T), + cudaMemcpyHostToDevice, stream); + } else { + cudaMemcpy2DAsync(buf.d_stg, (size_t)sb_cols * sizeof(T), h_X + col, + (size_t)n_cols * sizeof(T), + (size_t)sb_cols * sizeof(T), n_rows, + cudaMemcpyHostToDevice, stream); + } + + const float* keys_in; + if (fast_keys) { + keys_in = reinterpret_cast(buf.d_stg); + } else { + // grid-stride kernel: bounded grid covers any sb_items (<=INT_MAX) + // with no launch-math overflow. + const unsigned int grid = (unsigned int)std::min( + ((size_t)sb_items + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel<<>>( + buf.d_stg, buf.block_f32, n_rows, sb_cols, f_order); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + keys_in = buf.block_f32; + } + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "dense host OVR segmented sort"); + + // gmem rank mode atomicAdds without self-zeroing and the buffer is + // reused round-robin, so zero it first. + if (use_gmem) { + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "dense host OVR gmem rank_sums memset"); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVR rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check(cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream), + "dense host OVR tie_corr D2D copy"); + } + + // Group sums (+nnz) for means/pts, f64 from native staging (matches + // the Aggregate path). + if (compute_stats) { + cudaMemsetAsync(buf.sub_group_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + if (compute_nnz) { + cudaMemsetAsync(buf.sub_group_nnz, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + if (compute_totals) { + cudaMemsetAsync(buf.sub_total_sums, 0, sb_cols * sizeof(double), + stream); + if (compute_nnz) { + cudaMemsetAsync(buf.sub_total_nnz, 0, + sb_cols * sizeof(double), stream); + } + } + dense_group_accumulate_kernel + <<>>( + buf.d_stg, group_codes, buf.sub_group_sums, + compute_nnz ? buf.sub_group_nnz : buf.sub_group_sums, + buf.sub_total_sums, + compute_nnz ? buf.sub_total_nnz : buf.sub_total_sums, + n_rows, sb_cols, n_groups, f_order, compute_nnz, + compute_totals); + CUDA_CHECK_LAST_ERROR(dense_group_accumulate_kernel); + scatter_cols_2d(group_sums + col, buf.sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(group_nnz + col, buf.sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(total_sums + col, buf.sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(total_nnz + col, buf.sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense host OVR streaming"); } -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); +static void launch_ovo_rank_dense_tiered_unsorted_ref( + const float* ref_data, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + std::vector h_offsets(n_groups + 1); + cuda_check(cudaStreamSynchronize(upstream_stream), + "dense OVO sync before offsets D2H"); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "dense OVO group offsets D2H"); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "Dense OVO sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense OVO reference sub-batch"); + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense OVO group sub-batch"); + + size_t grp_cub_temp_bytes = 0; + if (run_huge) { + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense OVO group segment count"); + grp_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + } + size_t ref_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + + { + size_t per_stream = + sub_ref_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + ref_cub_temp_bytes + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? grp_cub_temp_bytes : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + n_streams = clamp_streams_by_budget(n_streams, per_stream, + rmm_available_device_bytes(0.8)); + } + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); + + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cuda_check(cudaStreamWaitEvent(streams[i], inputs_ready.get(), 0), + "wait on inputs_ready (dense OVO)"); + } + int* d_sort_group_ids = nullptr; + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "dense OVO sort group ids H2D"); + } + + struct StreamBuf { + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + bufs[s].grp_cub_temp = + run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; + // All tiers share the ref tie base, so allocate whenever correcting. + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense OVO group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense OVO active group sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = ref_data + (size_t)col * n_ref; + const float* grp_sub = grp_data + (size_t)col * n_all_grp; + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + cub_segmented_sortkeys(buf.ref_cub_temp, ref_cub_temp_bytes, ref_sub, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "dense OVO ref segmented sort"); + ref_sub = buf.ref_sorted; + + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.grp_cub_temp}; + ovo_dispatch_tiers(ref_sub, grp_sub, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense OVO rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check( + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense OVO tie_corr D2D copy"); + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense OVO tiered rank"); +} + +template +static void launch_ovo_rank_dense_host_streaming( + const T* h_X, bool f_order, const int* h_ref_row_ids, + const int* h_grp_row_ids, const int* h_grp_offsets, double* rank_sums, + double* tie_corr, double* group_sums, double* group_sum_sq, + double* group_nnz, int n_full_rows, int n_ref, int n_all_grp, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, bool compute_nnz, + bool compute_stats, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + if (compute_stats && n_groups_stats != n_groups + 1) { + throw std::runtime_error( + "dense OVO host stats require n_groups_stats == n_groups + 1"); + } + if (h_grp_offsets[0] != 0 || h_grp_offsets[n_groups] != n_all_grp) { + throw std::runtime_error( + "dense OVO host group offsets must span n_all_grp"); + } + + auto tier_plan = make_ovo_tier_plan(h_grp_offsets, n_groups); + int max_grp_size = tier_plan.max_grp_size; + bool run_large = tier_plan.above_medium && tier_plan.run_large; + bool run_huge = tier_plan.above_medium && !run_large; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "Dense host OVO sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense host OVO reference sub-batch"); + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense host OVO group sub-batch"); + constexpr bool fast_keys = std::is_same::value; + int n_stats_rows = n_groups + 1; + + size_t grp_cub_temp_bytes = 0; + if (run_huge) { + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense host OVO group segment count"); + grp_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + } + size_t ref_cub_temp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + + { + size_t native_items = sub_ref_items + sub_grp_items; + size_t per_stream = + native_items * sizeof(T) + + (fast_keys ? 0 : native_items * sizeof(float)) + + sub_ref_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + ref_cub_temp_bytes + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? grp_cub_temp_bytes : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + (compute_stats + ? 2 * (size_t)n_stats_rows * sub_batch_cols * sizeof(double) + : 0) + + (compute_nnz + ? (size_t)n_stats_rows * sub_batch_cols * sizeof(double) + : 0); + n_streams = clamp_streams_by_budget(n_streams, per_stream, + rmm_available_device_bytes(0.8)); + } + + RmmScratchPool pool; + PinnedRing stage(n_streams, batches.max_items); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + int* d_grp_offsets = pool.alloc(n_groups + 1); + cuda_check(cudaMemcpy(d_grp_offsets, h_grp_offsets, + (size_t)(n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice), + "dense host OVO offsets H2D"); + + int* d_sort_group_ids = nullptr; + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "dense host OVO sort group ids H2D"); + } + + int* d_grp_codes = nullptr; + if (compute_stats) { + std::vector h_grp_codes(n_all_grp, -1); + for (int g = 0; g < n_groups; g++) { + int begin = h_grp_offsets[g]; + int end = h_grp_offsets[g + 1]; + if (begin < 0 || end < begin || end > n_all_grp) { + throw std::runtime_error( + "dense OVO host group offsets are invalid"); + } + std::fill(h_grp_codes.begin() + begin, h_grp_codes.begin() + end, + g); + } + d_grp_codes = pool.alloc(n_all_grp); + cuda_check( + cudaMemcpy(d_grp_codes, h_grp_codes.data(), + (size_t)n_all_grp * sizeof(int), cudaMemcpyHostToDevice), + "dense host OVO group codes H2D"); + } + + struct StreamBuf { + T* ref_native; + T* grp_native; + float* ref_f32; + float* grp_f32; + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sum_sq; + double* sub_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_native = pool.alloc(sub_ref_items); + bufs[s].grp_native = pool.alloc(sub_grp_items); + bufs[s].ref_f32 = + fast_keys ? nullptr : pool.alloc(sub_ref_items); + bufs[s].grp_f32 = + fast_keys ? nullptr : pool.alloc(sub_grp_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + bufs[s].grp_cub_temp = + run_huge ? pool.alloc(grp_cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense host OVO group segments"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + bufs[s].sub_group_sums = + compute_stats + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + bufs[s].sub_group_sum_sq = + compute_stats + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz + ? pool.alloc((size_t)n_stats_rows * sub_batch_cols) + : nullptr; + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + int tpb = UTIL_BLOCK_SIZE; + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense host OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense host OVO active group sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + stage.wait(s); + T* h_ref_stage = stage.template get<0>(s); + T* h_grp_stage = stage.template get<1>(s); + + host_materialize_dense_rows_window(h_X, f_order, n_full_rows, n_cols, + h_ref_row_ids, n_ref, col, sb_cols, + h_ref_stage); + host_materialize_dense_rows_window(h_X, f_order, n_full_rows, n_cols, + h_grp_row_ids, n_all_grp, col, + sb_cols, h_grp_stage); + + cuda_check(cudaMemcpyAsync(buf.ref_native, h_ref_stage, + (size_t)sb_ref_items_actual * sizeof(T), + cudaMemcpyHostToDevice, stream), + "dense host OVO ref H2D"); + cuda_check(cudaMemcpyAsync(buf.grp_native, h_grp_stage, + (size_t)sb_grp_items_actual * sizeof(T), + cudaMemcpyHostToDevice, stream), + "dense host OVO group H2D"); + stage.record(s, stream); + + const float* ref_sub; + const float* grp_sub; + if (fast_keys) { + ref_sub = reinterpret_cast(buf.ref_native); + grp_sub = reinterpret_cast(buf.grp_native); + } else { + unsigned int ref_grid = (unsigned int)std::min( + ((size_t)sb_ref_items_actual + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel + <<>>( + buf.ref_native, buf.ref_f32, n_ref, sb_cols, true); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + unsigned int grp_grid = (unsigned int)std::min( + ((size_t)sb_grp_items_actual + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + 65535u); + dense_block_to_f32_kernel + <<>>( + buf.grp_native, buf.grp_f32, n_all_grp, sb_cols, true); + CUDA_CHECK_LAST_ERROR(dense_block_to_f32_kernel); + ref_sub = buf.ref_f32; + grp_sub = buf.grp_f32; + } + + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + cub_segmented_sortkeys(buf.ref_cub_temp, ref_cub_temp_bytes, ref_sub, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "dense host OVO ref segmented sort"); + ref_sub = buf.ref_sorted; + + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.grp_cub_temp}; + ovo_dispatch_tiers(ref_sub, grp_sub, d_grp_offsets, tier_plan, sc, + d_sort_group_ids, n_sort_groups, grp_cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + + cuda_check( + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVO rank_sums D2D copy"); + if (compute_tie_corr) { + cuda_check( + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream), + "dense host OVO tie_corr D2D copy"); + } + + if (compute_stats) { + cuda_check( + cudaMemsetAsync(buf.sub_group_sums, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group sums memset"); + cuda_check( + cudaMemsetAsync(buf.sub_group_sum_sq, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group sumsq memset"); + if (compute_nnz) { + cuda_check(cudaMemsetAsync( + buf.sub_group_nnz, 0, + (size_t)n_stats_rows * sb_cols * sizeof(double), + stream), + "dense host OVO group nnz memset"); + } + dense_ovo_group_stats_kernel<<>>( + buf.ref_native, buf.grp_native, d_grp_codes, buf.sub_group_sums, + buf.sub_group_sum_sq, + compute_nnz ? buf.sub_group_nnz : buf.sub_group_sums, n_ref, + n_all_grp, sb_cols, n_groups, compute_nnz); + CUDA_CHECK_LAST_ERROR(dense_ovo_group_stats_kernel); + scatter_cols_2d(group_sums + col, buf.sub_group_sums, n_stats_rows, + n_cols, sb_cols, stream); + scatter_cols_2d(group_sum_sq + col, buf.sub_group_sum_sq, + n_stats_rows, n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(group_nnz + col, buf.sub_group_nnz, + n_stats_rows, n_cols, sb_cols, stream); + } + } + + col += sb_cols; + ++batch_idx; + } + + sync_streams(streams, "dense host OVO streaming"); +} + +template +static void def_ovr_rank_dense_host_streaming(nb::module_& m) { + m.def( + "ovr_rank_dense_host_streaming", + [](HostArray X, gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, + gpu_array_c group_sums, + gpu_array_c group_nnz, + gpu_array_c total_sums, + gpu_array_c total_nnz, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_stats, + bool compute_totals, int sub_batch_cols) { + int n_rows = (int)X.shape(0); + int n_cols = (int)X.shape(1); + nb_require((int)group_codes.shape(0) == n_rows, + "ovr_rank_host: group_codes length must be n_rows"); + nb_require( + (int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovr_rank_host: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_cols, + "ovr_rank_host: tie_corr length must be n_cols"); + launch_ovr_rank_dense_host_streaming( + X.data(), FOrder, group_codes.data(), rank_sums.data(), + tie_corr.data(), compute_stats ? group_sums.data() : nullptr, + compute_nnz ? group_nnz.data() : nullptr, + compute_totals ? total_sums.data() : nullptr, + (compute_totals && compute_nnz) ? total_nnz.data() : nullptr, + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, + compute_totals, sub_batch_cols); + }, + "X"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, "group_sums"_a, + "group_nnz"_a, "total_sums"_a, "total_nnz"_a, nb::kw_only(), + "n_groups"_a, "compute_tie_corr"_a, "compute_nnz"_a, "compute_stats"_a, + "compute_totals"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); +} + +template +static void def_ovo_rank_dense_host_streaming(nb::module_& m) { + m.def( + "ovo_rank_dense_host_streaming", + [](HostArray X, host_array ref_row_ids, + host_array grp_row_ids, host_array grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, + gpu_array_c group_sums, + gpu_array_c group_sum_sq, + gpu_array_c group_nnz, int n_groups, + bool compute_tie_corr, bool compute_nnz, bool compute_stats, + int sub_batch_cols) { + int n_full_rows = (int)X.shape(0); + int n_cols = (int)X.shape(1); + int n_ref = (int)ref_row_ids.shape(0); + int n_all_grp = (int)grp_row_ids.shape(0); + nb_require((int)grp_offsets.shape(0) == n_groups + 1, + "ovo_rank_host: grp_offsets length must be n_groups+1"); + nb_require(rank_sums.ndim() == 2 && tie_corr.ndim() == 2, + "ovo_rank_host: rank_sums/tie_corr must be 2D"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovo_rank_host: rank_sums shape must be " + "(n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_groups && + (int)tie_corr.shape(1) == n_cols, + "ovo_rank_host: tie_corr shape must be " + "(n_groups, n_cols)"); + int n_groups_stats = compute_stats ? (int)group_sums.shape(0) : 0; + if (compute_stats) { + nb_require(group_sums.ndim() == 2 && group_sum_sq.ndim() == 2, + "ovo_rank_host: stats outputs must be 2D"); + nb_require(n_groups_stats == n_groups + 1 && + (int)group_sums.shape(1) == n_cols, + "ovo_rank_host: group_sums shape must be " + "(n_groups+1, n_cols)"); + nb_require((int)group_sum_sq.shape(0) == n_groups + 1 && + (int)group_sum_sq.shape(1) == n_cols, + "ovo_rank_host: group_sum_sq shape must be " + "(n_groups+1, n_cols)"); + if (compute_nnz) { + nb_require(group_nnz.ndim() == 2 && + (int)group_nnz.shape(0) == n_groups + 1 && + (int)group_nnz.shape(1) == n_cols, + "ovo_rank_host: group_nnz shape must be " + "(n_groups+1, n_cols)"); + } + } + launch_ovo_rank_dense_host_streaming( + X.data(), FOrder, ref_row_ids.data(), grp_row_ids.data(), + grp_offsets.data(), rank_sums.data(), tie_corr.data(), + compute_stats ? group_sums.data() : nullptr, + compute_stats ? group_sum_sq.data() : nullptr, + compute_nnz ? group_nnz.data() : nullptr, n_full_rows, n_ref, + n_all_grp, n_cols, n_groups, n_groups_stats, compute_tie_corr, + compute_nnz, compute_stats, sub_batch_cols); + }, + "X"_a, "ref_row_ids"_a, "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, "group_sums"_a, "group_sum_sq"_a, "group_nnz"_a, + nb::kw_only(), "n_groups"_a, "compute_tie_corr"_a, "compute_nnz"_a, + "compute_stats"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); } template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - // Tie correction kernel + def_ovr_rank_dense_host_streaming, + false>(m); + def_ovr_rank_dense_host_streaming, + true>(m); + def_ovr_rank_dense_host_streaming, false>(m); + def_ovr_rank_dense_host_streaming, true>(m); + def_ovo_rank_dense_host_streaming, + false>(m); + def_ovo_rank_dense_host_streaming, + true>(m); + def_ovo_rank_dense_host_streaming, false>(m); + def_ovo_rank_dense_host_streaming, true>(m); + m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, + "ovo_rank_dense_tiered_unsorted_ref", + [](gpu_array_f ref_data, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); + nb_require(ref_data.ndim() == 2 && grp_data.ndim() == 2 && + rank_sums.ndim() == 2 && tie_corr.ndim() == 2 && + grp_offsets.ndim() == 1, + "ovo_rank: data/outputs must be 2D, grp_offsets 1D"); + nb_require((int)ref_data.shape(0) == n_ref && + (int)ref_data.shape(1) == n_cols, + "ovo_rank: ref_data shape must be (n_ref, n_cols)"); + nb_require((int)grp_data.shape(0) == n_all_grp && + (int)grp_data.shape(1) == n_cols, + "ovo_rank: grp_data shape must be (n_all_grp, n_cols)"); + nb_require((int)grp_offsets.shape(0) >= n_groups + 1, + "ovo_rank: grp_offsets length must be >= n_groups + 1"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovo_rank: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_groups && + (int)tie_corr.shape(1) == n_cols, + "ovo_rank: tie_corr shape must be (n_groups, n_cols)"); + launch_ovo_rank_dense_tiered_unsorted_ref( + ref_data.data(), grp_data.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols, + (cudaStream_t)stream); }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "ref_data"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); - // Average rank kernel m.def( - "average_rank", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, + "ovr_rank_dense_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); + nb_require(block.ndim() == 2 && rank_sums.ndim() == 2 && + group_codes.ndim() == 1 && tie_corr.ndim() == 1, + "ovr_rank: block/rank_sums 2D, group_codes/tie_corr 1D"); + nb_require( + (int)block.shape(0) == n_rows && (int)block.shape(1) == n_cols, + "ovr_rank: block shape must be (n_rows, n_cols)"); + nb_require((int)group_codes.shape(0) == n_rows, + "ovr_rank: group_codes length must be n_rows"); + nb_require((int)rank_sums.shape(0) == n_groups && + (int)rank_sums.shape(1) == n_cols, + "ovr_rank: rank_sums shape must be (n_groups, n_cols)"); + nb_require((int)tie_corr.shape(0) == n_cols, + "ovr_rank: tie_corr length must be n_cols"); + launch_ovr_rank_dense_streaming( + block.data(), group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh new file mode 100644 index 00000000..156648c7 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_block_reduce.cuh @@ -0,0 +1,38 @@ +#pragma once + +#include + +// Sum `v` across the 32 lanes of a warp via shuffle-down; result on lane 0. +__device__ __forceinline__ double warp_reduce_sum(double v) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; +} + +// Block-wide sum of `val` using one shared double per warp. +// Result is returned on thread 0; other threads get 0.0. +__device__ __forceinline__ double wilcoxon_block_sum(double val, + double* warp_buf) { + val = warp_reduce_sum(val); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; + return warp_reduce_sum(v); + } + return 0.0; +} + +// Final tie-correction factor: 1 - sum(t^3 - t) / (n^3 - n), or 1.0 when the +// ranking population n_total is too small for a correction. +__device__ __forceinline__ double finalize_tie_corr(int n_total, + double tie_sum) { + double dn = (double)n_total; + double denom = dn * dn * dn - dn; + return (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh new file mode 100644 index 00000000..8e578c13 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include + +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR +#include "../rmm_scratch.h" // rmm_allocate, RmmScratchPool, ScopedCudaBuffer +#include "../sparse_extract/sparse_extract.cuh" // csr_extract_dense* kernels +#include "../streaming/streaming.cuh" + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 64; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; +// Scratch slots for warp-level reduction (one slot per warp, 32 warps max). +constexpr int WARP_REDUCE_BUF = 32; +// MEDIUM band cap: groups up to this size use unsorted O(n^2) in-group rank. +// Tier dispatch: make_ovo_tier_plan. +constexpr int OVO_MEDIUM_MAX = 512; +// LARGE band cap (fused smem-sort kernel); beyond it -> HUGE (CUB segmented +// sort). +constexpr int OVO_LARGE_MAX = 2500; +// Per-stream dense slab budget: 128M f32 items plus sorted copy ~= 1GB/stream. +// Sub-batching keeps (n_g * eff_sb_cols) within this. +constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; + +// Budget-aware OVO-host pack sizing for fixed per-stream scratch. +// Reserves dense/sorted slabs plus rank/tie/seg/CUB headroom. +constexpr size_t OVO_PACK_FIXED_PER_STREAM = + 4 * GROUP_DENSE_BUDGET_ITEMS * sizeof(float); // ~2 GB +// Floor for the budget-derived pack-nnz cap: avoid pathological over-splitting +// into thousands of tiny packs when device memory is very tight. +constexpr size_t OVO_MIN_PACK_NNZ = 64 * 1024 * 1024; // 64M nnz + +// H2D staging-ring slot cap: keeps page-locked footprint bounded per row-block. +// 32M nnz was the best compromise across small and multi-billion-nnz scales. +constexpr size_t STAGE_RING_NNZ_CAP = 32 * 1024 * 1024; + +// Query CUB segmented-radix-sort scratch size. Float keys, int values/offsets. +static inline size_t cub_segmented_sortkeys_temp_bytes(int num_items, + int num_segments) { + size_t bytes = 0; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cuda_check(cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, bytes, fk, fk, num_items, num_segments, doff, + doff + 1, BEGIN_BIT, END_BIT), + "CUB SortKeys temp-size query"); + return bytes; +} + +template +static inline size_t cub_segmented_sortpairs_temp_bytes(int num_items, + int num_segments) { + size_t bytes = 0; + auto* fk = reinterpret_cast(1); + auto* v = reinterpret_cast(1); + auto* off = reinterpret_cast(1); + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, bytes, fk, fk, v, v, num_items, num_segments, off, + off + 1, BEGIN_BIT, END_BIT), + "CUB SortPairs temp-size query"); + return bytes; +} + +// Launch wrappers. begin/end offset arrays may be contiguous (off, off+1) or +// distinct (starts, ends). +static inline void cub_segmented_sortkeys( + void* d_temp, size_t temp_bytes, const float* keys_in, float* keys_out, + int num_items, int num_segments, const int* begin_offsets, + const int* end_offsets, cudaStream_t stream, const char* what) { + cuda_check( + cub::DeviceSegmentedRadixSort::SortKeys( + d_temp, temp_bytes, keys_in, keys_out, num_items, num_segments, + begin_offsets, end_offsets, BEGIN_BIT, END_BIT, stream), + what); +} + +template +static inline void cub_segmented_sortpairs( + void* d_temp, size_t temp_bytes, const float* keys_in, float* keys_out, + const ValT* vals_in, ValT* vals_out, int num_items, int num_segments, + const int* begin_offsets, const int* end_offsets, cudaStream_t stream, + const char* what) { + cuda_check(cub::DeviceSegmentedRadixSort::SortPairs( + d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out, + num_items, num_segments, begin_offsets, end_offsets, + BEGIN_BIT, END_BIT, stream), + what); +} + +// Universal CUDA static per-block shared-memory floor. +// Safe fallback if the device query fails. +constexpr size_t WILCOXON_FALLBACK_SMEM_PER_BLOCK = 48 * 1024; + +// CRITICAL: cached per-device smem limit drives every smem/gmem/tier decision. +// Do not hardcode thresholds; sparse OVR fallback auto-scales with the GPU. +static inline size_t wilcoxon_max_smem_per_block() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return WILCOXON_FALLBACK_SMEM_PER_BLOCK; + } + static thread_local int cached_dev = -1; + static thread_local size_t cached_smem = 0; + if (device == cached_dev) return cached_smem; + int max_smem = 0; + if (cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + device) != cudaSuccess) { + return WILCOXON_FALLBACK_SMEM_PER_BLOCK; + } + cached_dev = device; + cached_smem = (size_t)max_smem; + return cached_smem; +} + +// Max per-batch nnz: a batch is sorted in one CUB segmented call (int32 item +// count) and addressed with int offsets, so it must stay below INT_MAX. +constexpr size_t SAFE_BATCH_NNZ = STREAMING_SAFE_BATCH_NNZ; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Per-row stats codes for a pack of K groups. + * Writes stats_codes[r] = base_slot + group_idx(r) by offset binary search. */ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +// Per-group stats over compact CSR, decoupled for host-staged data. +// Slot comes from stats_codes[r] or fixed_slot; out-of-range slots are skipped. +__global__ void csr_compact_accumulate_kernel( + const float* __restrict__ d_data_f32, const int* __restrict__ d_indices, + const int* __restrict__ d_indptr, const int* __restrict__ d_stats_codes, + int fixed_slot, double* __restrict__ group_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sums, bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + if (slot < 0 || slot >= n_groups_stats) return; + int rs = d_indptr[r]; + int re = d_indptr[r + 1]; + for (int i = rs + threadIdx.x; i < re; i += blockDim.x) { + int c = d_indices[i]; + double v = (double)d_data_f32[i]; + if (compute_sums) atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + if (compute_nnz && v != 0.0) + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..2a25e26c --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,413 @@ +#pragma once + +/** CSR-direct OVO pipeline: cache sorted reference columns once. + * Group sub-batches rank against that cache, matching the host-CSR path. */ +template +static void ovo_streaming_csr_impl( + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + DenseColumnBatchPlan group_batches = plan_dense_column_batches( + n_all_grp, n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO device CSR group sub-batch"); + sub_batch_cols = group_batches.sub_batch_cols; + + std::vector h_offsets(n_groups + 1); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "device OVO group offsets D2H"); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t max_ref_cols = 2147483647LL / (size_t)n_ref; + if (max_ref_cols == 0) { + throw std::runtime_error( + "OVO device CSR reference group exceeds CUB int item limit"); + } + int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); + { + // Ref cache uses dense+sorted floats per column/ref row. + // Size to ~1/3 allocator budget, leaving room for group buffers. + size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; + size_t target_bytes = rmm_available_device_bytes(1.0 / 3.0); + if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { + size_t mem_cols = target_bytes / bytes_per_col; + if (mem_cols > 0 && mem_cols < (size_t)ref_cache_cols) { + ref_cache_cols = (int)mem_cols; + } + } + } + if (ref_cache_cols < 1) ref_cache_cols = 1; + + RmmScratchPool pool; + + size_t cub_temp_bytes = 0; + if (run_huge) { + size_t cub_grp_bytes = 0; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSR group sub-batch"); + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment count"); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + cub_temp_bytes = cub_grp_bytes; + } + + // Clamp streams to budget: group slabs scale with cell count. + // Ref cache is allocated separately, so reserve its footprint first. + { + size_t per_stream = + sub_grp_items * sizeof(float) + + (run_huge ? sub_grp_items * sizeof(float) : 0) + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (run_huge ? cub_temp_bytes : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + size_t budget = rmm_available_device_bytes(0.8); + size_t ref_reserve = + 2 * (size_t)n_ref * (size_t)ref_cache_cols * sizeof(float); + budget = budget > ref_reserve ? budget - ref_reserve : 0; + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + ScopedCudaStream ref_stream(cudaStreamNonBlocking); + + int* d_sort_group_ids = nullptr; + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "device OVO sort group ids H2D"); + } + + struct StreamBuf { + float* grp_dense; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].cub_temp = + run_huge ? pool.alloc(cub_temp_bytes) : nullptr; + // LARGE/HUGE share the ref tie base: allocate whenever correcting. + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { + int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); + size_t cache_ref_items = (size_t)n_ref * cache_cols; + int cache_ref_items_i32 = checked_cub_items( + cache_ref_items, "OVO device CSR reference cache"); + + ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_seg_offsets_buf((size_t)(cache_cols + 1) * + sizeof(int)); + float* d_ref_dense = (float*)ref_dense_buf.data(); + float* d_ref_sorted = (float*)ref_sorted_buf.data(); + int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); + + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float), + ref_stream); + int tpb_ref_extract = round_up_to_warp(n_ref); + int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, + cache_col, cache_col + cache_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, ref_stream); + + size_t ref_cub_bytes = + cub_segmented_sortkeys_temp_bytes(cache_ref_items_i32, cache_cols); + ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); + cub_segmented_sortkeys(ref_cub_temp_buf.data(), ref_cub_bytes, + d_ref_dense, d_ref_sorted, cache_ref_items_i32, + cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, ref_stream, + "device CSR OVO ref segmented sort"); + cuda_check(cudaStreamSynchronize(ref_stream), + "device CSR OVO ref sort sync"); + + int col = cache_col; + int cache_stop = cache_col + cache_cols; + int batch_idx = 0; + while (col < cache_stop) { + int sb_cols = std::min(sub_batch_cols, cache_stop - col); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSR active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = + d_ref_sorted + (size_t)(col - cache_col) * n_ref; + + cudaMemsetAsync(buf.grp_dense, 0, + sb_grp_items_actual * sizeof(float), stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, + buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(ref_sub, buf.grp_dense, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "OVO device CSR streaming"); + } +} + +/** CSC-direct OVO pipeline: extracts rows via lookup maps. + * Operates on native CSC input without converting the matrix. */ +template +static void ovo_streaming_csc_impl( + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + DenseColumnBatchPlan batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO device CSC sub-batch"); + sub_batch_cols = batches.sub_batch_cols; + + std::vector h_offsets(n_groups + 1); + cuda_check(cudaMemcpy(h_offsets.data(), grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost), + "device OVO group offsets D2H"); + auto t1 = make_ovo_tier_plan(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_offsets.data(), n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO device CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSC group sub-batch"); + + size_t cub_ref_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + size_t cub_temp_bytes = cub_ref_bytes; + if (run_huge) { + size_t cub_grp_bytes = 0; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment count"); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Clamp streams to per-stream scratch budget. + // Ref/group slabs scale with cell counts, so fixed counts can OOM. + { + size_t per_stream = + 2 * sub_ref_items * sizeof(float) + + (run_huge ? 2 : 1) * sub_grp_items * sizeof(float) + + (size_t)(sub_batch_cols + 1) * sizeof(int) + cub_temp_bytes + + (run_huge ? 2 * (size_t)n_sort_groups * sub_batch_cols * sizeof(int) + : 0) + + (compute_tie_corr ? (size_t)sub_batch_cols * sizeof(double) : 0) + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double); + size_t budget = rmm_available_device_bytes(0.8); + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + int* d_sort_group_ids = nullptr; + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cuda_check(cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice), + "device OVO sort group ids H2D"); + } + + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + // LARGE/HUGE share the ref tie base: allocate whenever correcting. + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO device CSC active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSC active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + cub_segmented_sortkeys(buf.cub_temp, cub_temp_bytes, buf.ref_dense, + buf.ref_sorted, sb_ref_items_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "device CSC OVO ref segmented sort"); + + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + OvoTierScratch sc{buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(buf.ref_sorted, buf.grp_dense, grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_items_actual, tpb_rank, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "OVO device CSC streaming"); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..a8aae1cb --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,797 @@ +#pragma once + +struct OvoHostCsrPack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; +}; + +struct OvoHostCsrPackPlan { + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = 0; + size_t max_sub_items = 0; +}; + +template +static OvoHostCsrPackPlan plan_ovo_host_csr_packs( + const int* h_grp_offsets, const IndptrT* h_grp_indptr_compact, + int n_all_grp, int n_test, int n_cols, int n_ref, int sub_batch_cols) { + OvoHostCsrPackPlan plan; + plan.max_pack_sb_cols = sub_batch_cols; + + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + constexpr size_t SAFE_PACK_NNZ = 1500000000; // < INT_MAX, CUB-safe + size_t pack_nnz_cap = SAFE_PACK_NNZ; + { + int target_streams = std::min(N_STREAMS, n_test); + if (target_streams < 1) target_streams = 1; + size_t dev_budget = rmm_available_device_bytes(0.9); + size_t ref_bytes = (size_t)n_ref * (size_t)n_cols * sizeof(float); + size_t reserve = (size_t)target_streams * OVO_PACK_FIXED_PER_STREAM; + size_t grp_avail = dev_budget > ref_bytes ? dev_budget - ref_bytes : 0; + size_t data_avail = grp_avail > reserve ? grp_avail - reserve : 0; + size_t cap = data_avail / ((size_t)target_streams * 2 * sizeof(float)); + if (cap < OVO_MIN_PACK_NNZ) cap = OVO_MIN_PACK_NNZ; + if (cap < pack_nnz_cap) pack_nnz_cap = cap; + } + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows && + cur_nnz + nnz_g <= pack_nnz_cap); + if (!can_add) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + plan.packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + plan.packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + + for (const OvoHostCsrPack& pk : plan.packs) { + int K = pk.end - pk.first; + if (pk.n_rows > plan.max_pack_rows) plan.max_pack_rows = pk.n_rows; + if (pk.nnz > plan.max_pack_nnz) plan.max_pack_nnz = pk.nnz; + if (K > plan.max_pack_K) plan.max_pack_K = K; + int pack_items = + checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, + "OVO host CSR pack dense slab"); + if (pack_items > plan.max_pack_items) plan.max_pack_items = pack_items; + checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); + if (pk.sb_cols > plan.max_pack_sb_cols) + plan.max_pack_sb_cols = pk.sb_cols; + } + plan.max_sub_items = (size_t)plan.max_pack_items; + return plan; +} + +/** Host-streaming CSC OVO: send only each column sub-batch to GPU. + * Row maps/group offsets upload once; results scatter per sub-batch. */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, int n_ref, + int n_all_grp, int n_rows, int n_cols, int n_groups, int n_groups_stats, + bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // Cap sub_batch_cols so neither the dense ref/group slabs (rows × + // sub_batch_cols, one CUB call) nor per-batch nnz exceed int32. + DenseColumnBatchPlan dense_batches = plan_dense_column_batches( + std::max(n_ref, n_all_grp), n_cols, sub_batch_cols, SAFE_BATCH_NNZ, + "OVO host CSC dense sub-batch"); + sub_batch_cols = dense_batches.sub_batch_cols; + size_t sparse_cap = SAFE_BATCH_NNZ; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, sub_batch_cols, sparse_cap, + "OVO host CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + + auto t1 = make_ovo_tier_plan(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool run_large = t1.above_medium && t1.run_large; + bool run_huge = t1.above_medium && !run_large; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (run_huge) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, OVO_MEDIUM_MAX); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO host CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO host CSC group sub-batch"); + + // CUB temp + size_t cub_ref_bytes = + cub_segmented_sortkeys_temp_bytes(sub_ref_items_i32, sub_batch_cols); + size_t cub_temp_bytes = cub_ref_bytes; + if (run_huge) { + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC group segment count"); + size_t cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(sub_grp_items_i32, max_grp_seg); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + size_t max_nnz = batches.max_nnz; + constexpr size_t window_value_bytes = + sizeof(WilcoxonSparseWindowDTypes::value_type); + + // Clamp streams so per-stream scratch fits the budget: dense slabs scale + // with cell counts, so a fixed N_STREAMS would OOM at scale. + { + size_t per_stream = + sparse_window_nnz_bytes(max_nnz) + + 2 * sub_ref_items * window_value_bytes + + (run_huge ? 2 : 1) * sub_grp_items * window_value_bytes + + sparse_window_accum_bytes( + 2 * (size_t)n_groups * sub_batch_cols) + + (compute_nnz ? 2 : 1) * + sparse_window_accum_bytes( + (size_t)n_groups_stats * sub_batch_cols) + + cub_temp_bytes; + size_t budget = rmm_available_device_bytes(0.8); + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + + // pool first: streams drain before it frees their scratch (RAII order). + RmmScratchPool pool; + // Bounded staging avoids page-locking huge host CSC arrays and gives every + // dtype/index combination the same device footprint. + HostStagingRing stage(n_streams, max_nnz); + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + int* d_all_offsets = upload_batch_offsets(batches, pool); + + // Row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (run_huge) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* d_sparse_data_f32; + int* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + // LARGE/HUGE share the ref tie base: allocate whenever correcting. + bufs[s].ref_tie_sums = + compute_tie_corr ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (run_huge) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC stream group segment count"); + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_nnz, /*compute_totals=*/false, cast_use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO host CSC active reference sub-batch"); + int sb_grp_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO host CSC active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + int nnz_i = checked_int_span(nnz, "OVO host CSC active batch nnz"); + + // Cast-copy column batch into pinned staging, bulk H2D; the event lets + // the next copy overlap compute. + stage.wait(s); + host_cast_copy_slice(h_data, h_indices, (size_t)ptr_start, nnz_i, + stage.get<0>(s), stage.get<1>(s)); + cudaMemcpyAsync(buf.d_sparse_data_f32, stage.get<0>(s), + nnz * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, stage.get<1>(s), + nnz * sizeof(int), cudaMemcpyHostToDevice, stream); + stage.record(s, stream); + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Data already f32 on device: accumulate stats (cast is f32->f32 + // no-op). + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_f32, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_nnz, + nullptr, nullptr, sb_cols, n_groups_stats, compute_nnz, + /*compute_totals=*/false, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, + stream); + + // Extract ref from CSC via row_map, sort + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + cub_segmented_sortkeys(buf.cub_temp, cub_temp_bytes, buf.ref_dense, + buf.ref_sorted, sb_ref_actual, sb_cols, + buf.ref_seg_offsets, buf.ref_seg_offsets + 1, + stream, "host CSC OVO ref segmented sort"); + + // Extract grp from CSC via row_map + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // Tier dispatch: sort grp + rank + OvoTierScratch sc{buf.ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, buf.grp_sorted, + buf.grp_seg_offsets, buf.grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(buf.ref_sorted, buf.grp_dense, d_grp_offsets, t1, sc, + d_sort_group_ids, n_sort_groups, cub_temp_bytes, + sb_grp_actual, tpb_rank, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, stream); + + // D2D: scatter sub-batch results into caller's GPU buffers + scatter_cols_2d(d_rank_sums + col, buf.d_rank_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_tie_corr) { + scatter_cols_2d(d_tie_corr + col, buf.d_tie_corr, n_groups, n_cols, + sb_cols, stream); + } + scatter_cols_2d(d_group_sums + col, buf.d_group_sums, n_groups_stats, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, buf.d_group_nnz, n_groups_stats, + n_cols, sb_cols, stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "wilcoxon streaming"); +} + +/** Host CSR OVO pipeline with no full-matrix page-lock. + * Pinned pack staging feeds dense extract, sort, rank vs cached ref, scatter. + */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_nnz, int n_cols, int n_groups_stats, bool compute_tie_corr, + bool compute_nnz, bool compute_sums, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // Compacted indptrs on host. IndptrT for grp (can exceed 2^31 nnz when + // large/dense); ref stays int32 (n_ref × n_cols ≪ 2B, matches CUB temp). + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + IndptrT row_nnz = h_indptr[r + 1] - h_indptr[r]; + if ((size_t)row_nnz > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference row exceeds int32 compacted nnz limit"); + } + int nnz_i = (int)row_nnz; + if ((size_t)h_ref_indptr_compact[i] + (size_t)nnz_i > + (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference compacted nnz exceeds int32 limit"); + } + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows. + std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + OvoHostCsrPackPlan pack_plan = plan_ovo_host_csr_packs( + h_grp_offsets, h_grp_indptr_compact.data(), n_all_grp, n_test, n_cols, + n_ref, sub_batch_cols); + const std::vector& packs = pack_plan.packs; + int max_pack_rows = pack_plan.max_pack_rows; + size_t max_pack_nnz = pack_plan.max_pack_nnz; + int max_pack_K = pack_plan.max_pack_K; + int max_pack_sb_cols = pack_plan.max_pack_sb_cols; + size_t max_sub_items = pack_plan.max_sub_items; + if (max_pack_rows == 0) return; + + RmmScratchPool pool; + ScopedCudaStream ref_stream(cudaStreamNonBlocking); + + if (compute_sums) { + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double), + ref_stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double), + ref_stream); + } + + // No full-matrix page-lock: large cudaHostRegister was seconds per call. + // Gather reads pageable CSR and pins only small staging buffers. + (void)n_full_rows; + + // Pinned staging for the reference gather (compacted f32 vals + int32 + // cols). Uninitialized: the gather overwrites it, so skip a multi-GB zero. + size_t ref_stage_n = ref_nnz ? (size_t)ref_nnz : 1; + PinnedRing ref_stage(1, ref_stage_n); + + // Upload row_ids + compacted indptrs + group boundaries + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // Phase 1: ref setup with scoped scratch; sorted cache persists. + // Build by column chunk so CUB item counts and extract scratch stay + // bounded. + size_t ref_items = (size_t)n_ref * (size_t)n_cols; + if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { + throw std::runtime_error( + "OVO host CSR dense reference cache size overflows size_t"); + } + size_t ref_avail = rmm_available_device_bytes(0.9); + if (ref_avail > 0 && ref_items * sizeof(float) > ref_avail) { + throw std::runtime_error( + "OVO host CSR sorted reference cache requires more GPU memory than " + "is available; use native CSC/device sparse input or reduce " + "genes/reference size"); + } + int ref_chunk_cols = + n_ref > 0 + ? (int)std::min((size_t)n_cols, SAFE_BATCH_NNZ / (size_t)n_ref) + : n_cols; + if (ref_chunk_cols < 1) ref_chunk_cols = 1; + size_t ref_chunk_items = (size_t)n_ref * (size_t)ref_chunk_cols; + int ref_chunk_items_i32 = + checked_cub_items(ref_chunk_items, "OVO host CSR ref column chunk"); + float* d_ref_sorted = pool.alloc(ref_items); + { + ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); + ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); + ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); + ScopedCudaBuffer ref_dense_buf(ref_chunk_items * sizeof(float)); + ScopedCudaBuffer ref_seg_buf((ref_chunk_cols + 1) * sizeof(int)); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Host-gather ref rows into pinned staging, bulk H2D, accumulate stats. + if (n_ref > 0 && ref_nnz > 0) { + host_gather_rows_compact(h_data, h_indices, h_indptr, h_ref_row_ids, + h_ref_indptr_compact.data(), 0, n_ref, + ref_stage.get<0>(0), ref_stage.get<1>(0)); + cuda_check(cudaMemcpyAsync(d_ref_data_f32, ref_stage.get<0>(0), + (size_t)ref_nnz * sizeof(float), + cudaMemcpyHostToDevice, ref_stream), + "OVO host CSR ref staged vals H2D"); + cuda_check(cudaMemcpyAsync(d_ref_indices, ref_stage.get<1>(0), + (size_t)ref_nnz * sizeof(int), + cudaMemcpyHostToDevice, ref_stream), + "OVO host CSR ref staged cols H2D"); + ref_stage.record(0, ref_stream); + if (compute_sums || compute_nnz) { + csr_compact_accumulate_kernel<<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, + /*d_stats_codes=*/nullptr, /*fixed_slot=*/n_test, + d_group_sums, d_group_nnz, n_ref, n_cols, n_groups_stats, + compute_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_compact_accumulate_kernel); + } + } + + size_t ref_cub_bytes = cub_segmented_sortkeys_temp_bytes( + ref_chunk_items_i32, ref_chunk_cols); + ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); + + // Extract + segment-sort the reference per column chunk. + for (int cs = 0; cs < n_cols; cs += ref_chunk_cols) { + int ce = std::min(cs + ref_chunk_cols, n_cols); + int cc = ce - cs; + size_t chunk_items = (size_t)n_ref * (size_t)cc; + cudaMemsetAsync(d_ref_dense, 0, chunk_items * sizeof(float), + ref_stream); + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, d_ref_dense, + n_ref, cs, ce); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + upload_linear_offsets(d_ref_seg, cc, n_ref, ref_stream); + cub_segmented_sortkeys( + cub_temp_buf.data(), ref_cub_bytes, d_ref_dense, + d_ref_sorted + (size_t)cs * (size_t)n_ref, (int)chunk_items, cc, + d_ref_seg, d_ref_seg + 1, ref_stream, + "host CSR OVO ref segmented sort"); + } + cuda_check(cudaStreamSynchronize(ref_stream), + "host CSR OVO ref sort sync"); + } // ref scratch drops here + + // Phase 2: Per-pack streaming + auto t1 = make_ovo_tier_plan(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > OVO_LARGE_MAX); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + int max_sub_items_i32 = + checked_cub_items(max_sub_items, "OVO host CSR group pack"); + int max_segments = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR max group segment count"); + cub_grp_bytes = + cub_segmented_sortkeys_temp_bytes(max_sub_items_i32, max_segments); + } + int max_pack_kernel_seg = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR pack segment buffer"); + constexpr size_t window_value_bytes = + sizeof(WilcoxonSparseWindowDTypes::value_type); + + // Clamp streams to the post-ref free-memory budget. + // Per-stream pack buffers dominate; fewer streams reduce overlap only. + { + size_t per_stream = + sparse_window_nnz_bytes(max_pack_nnz) + + (size_t)(max_pack_rows + 1) * sizeof(int) // grp indptr + + (size_t)max_pack_rows * sizeof(int) // stats codes + + (size_t)(max_pack_K + 1) * sizeof(int) // pack grp offsets + + max_sub_items * window_value_bytes // grp dense + + sparse_window_accum_bytes( + 2 * (size_t)max_pack_K * max_pack_sb_cols) // rank+tie + + sparse_window_accum_bytes( + (size_t)max_pack_sb_cols) // ref tie + + + (may_need_cub + ? max_sub_items * window_value_bytes // grp sorted + + (size_t)max_pack_K * sizeof(int) // sort ids + + 2 * (size_t)max_pack_kernel_seg * sizeof(int) // segs + + cub_grp_bytes // cub temp + : 0); + size_t budget = rmm_available_device_bytes(0.9); + n_streams = clamp_streams_by_budget(n_streams, per_stream, budget); + } + + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + // Rolling pinned staging fills pack device buffers in <= stage_cap nnz + // blocks. This keeps page-locked footprint small while extra slots overlap + // H2D. + size_t stage_cap = std::min(max_pack_nnz, STAGE_RING_NNZ_CAP); + int ring_slots = n_streams + 2; + HostStagingRing stage(ring_slots, stage_cap); + int stage_slot = 0; + + for (int p = 0; p < (int)packs.size(); p++) { + const OvoHostCsrPack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + OvoTierPlan pack_t1 = make_ovo_tier_plan(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + // HUGE skips groups MEDIUM already handled (≤ OVO_MEDIUM_MAX). + int pack_huge_skip_le = OVO_MEDIUM_MAX; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.above_medium && !pack_t1.run_large) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_huge_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.above_medium && !pack_t1.run_large) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr (IndptrT → int32: pack nnz is bounded by + // GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Per-pack group offsets on GPU — needed for stats codes. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Host-gather pack rows into rolling staging blocks, then H2D by + // offset. Stats accumulate once over the full device-resident pack. + if (pack.nnz > 0) { + IndptrT pack_base = h_grp_indptr_compact[row_start]; + int rb0 = 0; + while (rb0 < pack_rows) { + IndptrT blk_base = h_grp_indptr_compact[row_start + rb0]; + int rb1 = rb0 + 1; + while (rb1 < pack_rows && + (size_t)(h_grp_indptr_compact[row_start + rb1 + 1] - + blk_base) <= stage_cap) + rb1++; + size_t blk_nnz = + (size_t)(h_grp_indptr_compact[row_start + rb1] - blk_base); + size_t dev_off = (size_t)(blk_base - pack_base); + int slot = stage_slot % ring_slots; + stage_slot++; + // wait drains a prior H2D out of this slot before we overwrite + // it; the event lets the next gather overlap the in-flight H2D. + stage.wait(slot); + host_gather_rows_compact( + h_data, h_indices, h_indptr, + h_grp_row_ids + row_start + rb0, + h_grp_indptr_compact.data() + row_start + rb0, blk_base, + rb1 - rb0, stage.get<0>(slot), stage.get<1>(slot)); + cuda_check( + cudaMemcpyAsync(buf.d_grp_data_f32 + dev_off, + stage.get<0>(slot), blk_nnz * sizeof(float), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged vals H2D"); + cuda_check( + cudaMemcpyAsync(buf.d_grp_indices + dev_off, + stage.get<1>(slot), blk_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream), + "OVO host CSR pack staged cols H2D"); + stage.record(slot, stream); + rb0 = rb1; + } + if (compute_sums || compute_nnz) { + csr_compact_accumulate_kernel<<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, d_group_sums, + d_group_nnz, pack_rows, n_cols, n_groups_stats, + compute_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_compact_accumulate_kernel); + } + } + + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = + checked_int_product((size_t)pack_rows, (size_t)sb_cols, + "OVO host CSR active group sub-batch"); + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_grp_dense, pack_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + OvoTierScratch sc{buf.d_ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, buf.d_grp_sorted, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, + buf.cub_temp}; + ovo_dispatch_tiers(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + pack_t1, sc, buf.d_sort_group_ids, + pack_n_sort_groups, cub_grp_bytes, sb_items, + pack_tpb_rank, n_ref, pack_rows, sb_cols, K, + compute_tie_corr, stream); + + scatter_cols_2d(d_rank_sums + (size_t)pack.first * n_cols + col, + buf.d_rank_sums, K, n_cols, sb_cols, stream); + if (compute_tie_corr) { + scatter_cols_2d(d_tie_corr + (size_t)pack.first * n_cols + col, + buf.d_tie_corr, K, n_cols, sb_cols, stream); + } + + col += sb_cols; + } + } + + sync_streams(streams, "ovo csr host streaming"); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..30170651 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,214 @@ +#pragma once + +#include + +#include "../sparse_extract/sparse_extract.cuh" + +/** Build CUB segmented-sort ranges for HUGE-band groups. + * Ranges point into the original dense group layout. */ +__global__ void build_huge_seg_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +template +__global__ void dense_ovo_group_stats_kernel( + const T* __restrict__ ref_dense, const T* __restrict__ grp_dense, + const int* __restrict__ grp_codes, double* __restrict__ group_sums, + double* __restrict__ group_sum_sq, double* __restrict__ group_nnz, + int n_ref, int n_all_grp, int sb_cols, int n_groups, bool compute_nnz) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int ref_slot = n_groups; + const T* ref_col = ref_dense + (size_t)col * n_ref; + const T* grp_col = grp_dense + (size_t)col * n_all_grp; + + for (int row = threadIdx.x; row < n_ref; row += blockDim.x) { + double v = (double)ref_col[row]; + atomicAdd(&group_sums[(size_t)ref_slot * sb_cols + col], v); + atomicAdd(&group_sum_sq[(size_t)ref_slot * sb_cols + col], v * v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)ref_slot * sb_cols + col], 1.0); + } + } + + for (int row = threadIdx.x; row < n_all_grp; row += blockDim.x) { + int g = grp_codes[row]; + if (g < 0 || g >= n_groups) continue; + double v = (double)grp_col[row]; + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + atomicAdd(&group_sum_sq[(size_t)g * sb_cols + col], v * v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } +} + +/** Sizing knobs for LARGE/HUGE dispatch. + * LARGE uses fused smem sort; HUGE uses CUB sort plus pre-sorted rank. */ +struct OvoTierPlan { + int max_grp_size = 0; + bool run_medium = false; // MEDIUM band: any group ≤ OVO_MEDIUM_MAX + bool run_large = false; // LARGE band: (OVO_MEDIUM_MAX, OVO_LARGE_MAX] + bool above_medium = false; // at least one group exceeds OVO_MEDIUM_MAX + int large_padded = 0; + int large_tpb = 0; + size_t large_smem = 0; +}; + +// Single OVO tier planner shared by dense and all sparse implementations. +// MEDIUM co-launches; LARGE falls back to HUGE if smem exceeds device limits. +static OvoTierPlan make_ovo_tier_plan(const int* h_grp_offsets, int n_groups) { + OvoTierPlan c; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz <= OVO_MEDIUM_MAX) c.run_medium = true; + if (sz > OVO_MEDIUM_MAX) c.above_medium = true; + } + + // run_large: fused smem-sort fast path for groups > MEDIUM but <= LARGE. + c.run_large = c.above_medium && (c.max_grp_size <= OVO_LARGE_MAX); + if (c.run_large) { + c.large_padded = 1; + while (c.large_padded < c.max_grp_size) c.large_padded <<= 1; + c.large_tpb = std::min(c.large_padded, MAX_THREADS_PER_BLOCK); + // dynamic smem = grp_smem only; warp_buf is static in the kernel. + c.large_smem = (size_t)c.large_padded * sizeof(float); + // Device-adapt fused-sort smem to the per-block limit. + // If it no longer fits, fall back to HUGE with no smem cap. + if (c.large_smem > wilcoxon_max_smem_per_block()) { + c.run_large = false; + } + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_ovo_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = (size_t)OVO_MEDIUM_MAX * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_rank_medium_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + OVO_MEDIUM_MAX); + CUDA_CHECK_LAST_ERROR(ovo_rank_medium_kernel); +} + +// Per-stream scratch for ovo_dispatch_tiers (one set per CUDA stream). +// grp_sorted/grp_seg_*/grp_cub_temp are HUGE-band only; may be null otherwise. +struct OvoTierScratch { + double* ref_tie_sums; // [sb_cols] pre-computed reference tie sums, or null + double* sub_rank_sums; // [n_groups * sb_cols] rank-sum output accumulator + double* sub_tie_corr; // [n_groups * sb_cols] tie-correction output + float* grp_sorted; // HUGE: [n_all_grp * sb_cols] sorted group values + int* grp_seg_offsets; // HUGE: CUB segment begins + int* grp_seg_ends; // HUGE: CUB segment ends + uint8_t* grp_cub_temp; // HUGE: CUB scratch +}; + +// Single OVO ranking engine shared by dense and all sparse host/device paths. +// Callers differ only in how they produce ref_sorted and grp_dense. +static inline void ovo_dispatch_tiers( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const OvoTierPlan& plan, const OvoTierScratch& sc, + const int* d_sort_group_ids, int n_sort_groups, size_t grp_cub_temp_bytes, + int sb_grp_items_actual, int tpb_rank, int n_ref, int n_all_grp, + int sb_cols, int n_groups, bool compute_tie_corr, cudaStream_t stream) { + // No-tie fast path: rank unsorted group values vs sorted ref (U-identity). + // Skips group sort and all tier kernels. + if (!compute_tie_corr) { + constexpr int VS_REF_BLOCK = 256; + dim3 grid(sb_cols, n_groups); + ovo_rank_dense_vs_ref_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, sc.sub_rank_sums, n_ref, + n_all_grp, sb_cols, n_groups); + CUDA_CHECK_LAST_ERROR(ovo_rank_dense_vs_ref_kernel); + return; + } + bool run_large = plan.above_medium && plan.run_large; + bool run_huge = plan.above_medium && !run_large; + + // All tiers share the precomputed reference tie base; compute once/column. + if (compute_tie_corr) { + launch_ref_tie_sums(ref_sorted, sc.ref_tie_sums, n_ref, sb_cols, + stream); + } + // MEDIUM handles every group <= OVO_MEDIUM_MAX (skip_n_grp_le = 0); + // LARGE/HUGE take the groups above MEDIUM. + if (plan.run_medium) { + launch_ovo_medium(ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, /*skip=*/0, + stream); + } + + int upper_skip_le = plan.above_medium ? OVO_MEDIUM_MAX : 0; + if (plan.above_medium && run_large) { + dim3 grid(sb_cols, n_groups); + ovo_rank_sorted_kernel + <<>>( + ref_sorted, grp_dense, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, plan.large_padded, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_sorted_kernel); + } else if (run_huge) { + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "OVO active group segment count"); + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_huge_seg_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, sc.grp_seg_offsets, sc.grp_seg_ends, + n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_huge_seg_offsets_kernel); + + cub_segmented_sortkeys(sc.grp_cub_temp, grp_cub_temp_bytes, grp_dense, + sc.grp_sorted, sb_grp_items_actual, sb_grp_seg, + sc.grp_seg_offsets, sc.grp_seg_ends, stream, + "OVO huge-tier group segmented sort"); + + dim3 grid(sb_cols, n_groups); + ovo_rank_sorted_kernel<<>>( + ref_sorted, sc.grp_sorted, grp_offsets, sc.ref_tie_sums, + sc.sub_rank_sums, sc.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, /*large_padded=*/0, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_rank_sorted_kernel); + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..fd8e1cdc --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,106 @@ +#pragma once + +#include "../sparse_extract/sparse_extract.cuh" + +/** Count nonzeros per column from CSR. One thread per row. */ +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + unsigned int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; + if (c >= 0 && c < n_cols) atomicAdd(&col_counts[c], 1u); + } +} + +// CRITICAL: dense OVR gmem fallback is load-bearing for large n_groups. +// Shared-memory thresholds are device-queried; oversized smem would not launch. +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** CRITICAL: sparse OVR gmem fallback is required for Perturb-seq-scale groups. + * Shared-memory thresholds are device-queried; oversized smem cannot launch. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** Fill sort values with row indices [0, 1, ..., n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} + +/** Read one dense column batch into f32 F-order for segmented sort. + * F-order is identity cast; C-order reads into F-order while casting. */ +template +__global__ void dense_block_to_f32_kernel(const T* __restrict__ stg, + float* __restrict__ out, int n_rows, + int sb_cols, bool f_order) { + const long long total = (long long)n_rows * sb_cols; + const long long stride = (long long)gridDim.x * blockDim.x; + for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x; + idx < total; idx += stride) { + if (f_order) { + out[idx] = (float)stg[idx]; + } else { + int col = (int)(idx / n_rows); + int row = (int)(idx % n_rows); + out[idx] = (float)stg[(long long)row * sb_cols + col]; + } + } +} + +/** Accumulate dense batch per-group sums and optional nnz in f64. + * Reads native staging so means match Aggregate; ranking cast is separate. */ +template +__global__ void dense_group_accumulate_kernel( + const T* __restrict__ stg, const int* __restrict__ group_codes, + double* __restrict__ group_sums, double* __restrict__ group_nnz, + double* __restrict__ total_sums, double* __restrict__ total_nnz, int n_rows, + int sb_cols, int n_groups, bool f_order, bool compute_nnz, + bool compute_totals) { + int col = blockIdx.x; + if (col >= sb_cols) return; + for (int row = threadIdx.x; row < n_rows; row += blockDim.x) { + double v = f_order ? (double)stg[(long long)col * n_rows + row] + : (double)stg[(long long)row * sb_cols + col]; + if (compute_totals) { + atomicAdd(&total_sums[col], v); + if (compute_nnz && v != 0.0) { + atomicAdd(&total_nnz[col], 1.0); + } + } + int g = group_codes[row]; + if (g < 0 || g >= n_groups) continue; + atomicAdd(&group_sums[(long long)g * sb_cols + col], v); + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(long long)g * sb_cols + col], 1.0); + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..8cff570d --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,1511 @@ +#pragma once + +// Host-streaming CSC OVR: sort only stored nonzeros per column. +// GPU memory is O(max_batch_nnz), not O(n_rows * n_cols). +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Bound each batch's nnz: CUB item counts stay within int32 + per-stream + // sort buffers fit the budget (column counts free from CSC indptr). + size_t cap = SAFE_BATCH_NNZ; + { + constexpr size_t BYTES_PER_NNZ = + sizeof(InT) + 2 * sizeof(float) + 2 * sizeof(IndexT) + 8; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + } + + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, sub_batch_cols, cap, + "OVR host CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t max_nnz = batches.max_nnz; + + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); + } + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + size_t total_nnz = (size_t)h_indptr[n_cols]; + size_t direct_pin_bytes = total_nnz * (sizeof(InT) + sizeof(IndexT)); + bool use_bounded_stage = + direct_pin_bytes > HOST_STREAMING_DIRECT_PIN_LIMIT_BYTES; + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + std::unique_ptr> stage; + if (use_bounded_stage) { + stage.reset(new PinnedRing(n_streams, max_nnz)); + } else { + pin_data = HostRegisterGuard(const_cast(h_data), + total_nnz * sizeof(InT)); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + } + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* idx_i32; // int32 sort-val scratch; only used when IndexT != int + int* d_seg_offsets; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_nnz; + double* d_total_sums; + double* d_total_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].idx_i32 = + (sizeof(IndexT) > sizeof(int)) ? pool.alloc(max_nnz) : nullptr; + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].d_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; + } + + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets, upload once (no per-batch H2D). + int* d_all_offsets = upload_batch_offsets(batches, pool); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + // gmem mode: rank kernel accumulates into rank_sums directly, needs a + // per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR host CSC active batch nnz"); + + if (use_bounded_stage) { + // Bounded staging: copy native values/indices into a small pinned + // slot instead of page-locking the whole host CSC. + stage->wait(s); + if (batch_nnz > 0) { + host_copy_slice(h_data, h_indices, (size_t)ptr_start, batch_nnz, + stage->template get<0>(s), + stage->template get<1>(s)); + cudaMemcpyAsync(buf.d_sparse_data_orig, + stage->template get<0>(s), + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, stage->template get<1>(s), + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + stage->record(s, stream); + } else if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + + // Row indices are the sort values; downcast int64 -> int32 at the + // device boundary (values < n_rows < 2^31) so sort + rank stay int32. + int* idx32; + if constexpr (sizeof(IndexT) > sizeof(int)) { + if (batch_nnz > 0) { + int cblk = (batch_nnz + tpb - 1) / tpb; + cast_array_kernel<<>>( + buf.d_sparse_indices, buf.idx_i32, (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + } + idx32 = buf.idx_i32; + } else { + idx32 = buf.d_sparse_indices; + } + + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, idx32, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_nnz, + buf.d_total_sums, buf.d_total_nnz, sb_cols, n_groups, compute_nnz, + compute_totals, tpb, smem_cast, cast_use_gmem, stream); + + // Sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, + buf.d_sparse_data_f32, buf.keys_out, idx32, + buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, + stream, "host CSC OVR segmented sort"); + } + + launch_ovr_sparse_rank( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, tpb, smem_bytes, compute_tie_corr, + rank_use_gmem, stream); + + scatter_cols_2d(d_rank_sums + col, buf.d_rank_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + scatter_cols_2d(d_group_sums + col, buf.d_group_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, buf.d_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, buf.d_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, buf.d_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "sparse host CSC streaming"); +} + +// Host CSR rowstream OVR for matrices too large to stage on the GPU. +// Sorted rows let cursors advance once, so each nnz is gathered/transferred +// once. +template +static void ovr_sparse_csr_host_rowstream_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + RmmScratchPool pool; + int tpb = UTIL_BLOCK_SIZE; + size_t budget = rmm_available_device_bytes(0.8); + + // Host column histogram; each worker counts privately, then merges. + std::vector h_col_counts(n_cols, 0); + { + int n_workers = host_worker_count(); + std::vector> local(n_workers, + std::vector(n_cols, 0)); + int used = host_parallel_chunks(n_rows, [&](int w, int r0, int r1) { + std::vector& lc = local[w]; + for (IndptrT p = h_indptr[r0]; p < h_indptr[r1]; p++) + lc[(size_t)h_indices[p]]++; + }); + for (int w = 0; w < used; w++) + for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; + } + + // Column batches must satisfy int32 CUB limits and device memory budget. + // Per-nnz scratch covers mini-CSR gather, CSC accum, sort output, and CUB. + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(InT) // gather val + csc val + + 2 * sizeof(float) // f32 key in + out + + 3 * sizeof(int) // gather col + 2 rows + + 12; // CUB temp headroom + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = budget / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, [&](int c) { return h_col_counts[c]; }, + "rowstream rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; + + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int mb_i32 = + checked_cub_items(max_batch_nnz, "rowstream sub-batch nnz"); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(mb_i32, sub_batch_cols); + } + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + // Host gather staging is pinned; full CSR stays pageable on CPU. + // Only the compacted column interval crosses the bus. + size_t stage_nnz = max_batch_nnz ? max_batch_nnz : 1; + PinnedRing gather_stage(1, stage_nnz); + PinnedRing indptr_stage(1, (size_t)n_rows + 1); + std::vector cursor(n_rows, 0); // offset within each (sorted) row + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + InT* d_gather_vals = pool.alloc(max_batch_nnz); + int* d_gather_cols = pool.alloc(max_batch_nnz); + int* d_gather_indptr = pool.alloc(n_rows + 1); + int* col_offsets = pool.alloc(sub_batch_cols + 1); + int* write_pos = pool.alloc(sub_batch_cols); + int* d_all_offsets = upload_batch_offsets(batches, pool); + InT* csc_vals_orig = pool.alloc(max_batch_nnz); + float* csc_vals_f32 = pool.alloc(max_batch_nnz); + int* csc_row_idx = pool.alloc(max_batch_nnz); + float* keys_out = pool.alloc(max_batch_nnz); + int* vals_out = pool.alloc(max_batch_nnz); + uint8_t* cub_temp = pool.alloc(cub_temp_bytes); + double* sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + double* sub_tie_corr = pool.alloc(sub_batch_cols); + double* sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; + double* d_nz_scratch = + rank_use_gmem ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + ScopedCudaStream row_stream(cudaStreamDefault); + cudaStream_t stream = row_stream.get(); + + // One ascending column pass; sorted-row cursors make transfer one-shot. + // Threaded gather counts row runs, prefix-sums, then copies disjoint + // ranges. + std::vector g_count(n_rows); + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int col_end = col + sb_cols; + gather_stage.wait(0); + indptr_stage.wait(0); + InT* h_gather_vals = gather_stage.template get<0>(0); + int* h_gather_cols = gather_stage.template get<1>(0); + int* h_gather_indptr = indptr_stage.template get<0>(0); + + int batch_nnz = host_materialize_csr_column_interval_cursor( + h_data, h_indices, h_indptr, n_rows, col, col_end, cursor.data(), + g_count.data(), h_gather_indptr, h_gather_vals, h_gather_cols, + "rowstream gather nnz"); + + int* off = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(col_offsets, off, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(write_pos, off, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Bulk H2D of this batch's compacted nonzeros (1x transfer). + if (batch_nnz > 0) { + cuda_check(cudaMemcpyAsync(d_gather_vals, h_gather_vals, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream), + "rowstream gathered vals H2D"); + cuda_check(cudaMemcpyAsync(d_gather_cols, h_gather_cols, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream), + "rowstream gathered cols H2D"); + } + cudaMemcpyAsync(d_gather_indptr, h_gather_indptr, + (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + gather_stage.record(0, stream); + indptr_stage.record(0, stream); + + // Scatter mini-CSR into the column-batch CSC accumulator. + csr_scatter_to_csc_kernel + <<<(n_rows + tpb - 1) / tpb, tpb, 0, stream>>>( + d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, + csc_vals_orig, csc_row_idx, n_rows, col, col_end, 0); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + launch_ovr_cast_and_accumulate_sparse( + csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, tpb, + smem_cast, cast_use_gmem, stream); + if (batch_nnz > 0) { + cub_segmented_sortpairs(cub_temp, cub_temp_bytes, csc_vals_f32, + keys_out, csc_row_idx, vals_out, batch_nnz, + sb_cols, col_offsets, col_offsets + 1, + stream, "rowstream segmented sort"); + } + launch_ovr_sparse_rank( + keys_out, vals_out, col_offsets, d_group_codes, d_group_sizes, + sub_rank_sums, sub_tie_corr, d_nz_scratch, n_rows, sb_cols, + n_groups, tpb, smem_bytes, compute_tie_corr, rank_use_gmem, stream); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) + cudaMemcpyAsync(d_tie_corr + col, sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + col += sb_cols; + } + cuda_check(cudaStreamSynchronize(stream), "rowstream sync"); +} + +// Host CSR sparse OVR stream: keep CSR on host and batch CSR->CSC scatter. +// Avoids full sparse upload and whole-matrix CSR->CSC conversion. +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_nnz, + double* d_total_sums, double* d_total_nnz, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, bool compute_nnz, bool compute_totals, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Declared before pool/streams: on exception unwind streams drain (kernels + // finish reading mapped host memory) before unregistration. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + + RmmScratchPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + size_t budget = rmm_available_device_bytes(0.8); + + int tpb = UTIL_BLOCK_SIZE; + size_t data_bytes = total_nnz * sizeof(InT); + size_t idx_bytes = total_nnz * sizeof(IndexT); + + // Too large to stage on device: per-batch scatter would fall back to + // bus-latency-bound zero-copy reads. Page the CSR through in row blocks. + if (total_nnz > 0 && data_bytes + idx_bytes > (budget * 3) / 4) { + ovr_sparse_csr_host_rowstream_impl( + h_data, h_indices, h_indptr, h_group_codes, h_group_sizes, + d_rank_sums, d_tie_corr, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, n_rows, n_cols, n_groups, compute_tie_corr, + compute_nnz, compute_totals, sub_batch_cols); + return; + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // Stage indices first when they fit so histogram/scatter read at HBM speed. + // Data is staged too only if data plus one stream buffer still fits. + IndexT* d_indices = nullptr; + bool indices_staged = total_nnz > 0 && idx_bytes <= budget / 2; + if (total_nnz > 0) { + if (indices_staged) { + d_indices = pool.alloc(total_nnz); + cuda_check(cudaMemcpy(d_indices, h_indices, idx_bytes, + cudaMemcpyHostToDevice), + "OVR host CSR stage indices H2D"); + } else { + pin_indices = HostRegisterGuard(const_cast(h_indices), + idx_bytes, cudaHostRegisterMapped); + cuda_check( + cudaHostGetDevicePointer((void**)&d_indices, + const_cast(h_indices), 0), + "OVR host CSR map indices"); + } + } + + // Count per-column nnz on GPU; CSR has no native column structure. + // Only n_cols counts are copied back for batch planning. + std::vector h_col_counts(n_cols, 0); + if (total_nnz > 0) { + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); + int hist_blocks = (n_rows + tpb - 1) / tpb; + csr_col_histogram_kernel<<>>( + d_indices, d_indptr_full, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + cuda_check( + cudaMemcpy(h_col_counts.data(), d_col_counts, + n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost), + "OVR host CSR column-count D2H"); + } + + // Each batch uses one CUB segmented sort and per-stream CSR->CSC scratch. + // Shrink sub_batch_cols until item counts and memory budget both fit. + constexpr size_t BYTES_PER_NNZ = sizeof(InT) + sizeof(float) + + 2 * sizeof(int) + 8; // buffers + CUB temp + size_t batch_nnz_cap = SAFE_BATCH_NNZ; + size_t mem_cap = budget / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < batch_nnz_cap) batch_nnz_cap = mem_cap; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, batch_nnz_cap, + [&](int c) { return (size_t)h_col_counts[c]; }, + "OVR host CSR rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; + int* d_all_offsets = upload_batch_offsets(batches, pool); + + // ---- Phase 1: per-stream bounded work buffer size + stream count ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR host CSR sparse sub-batch nnz"); + cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(max_batch_nnz_i32, + sub_batch_cols); + } + + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_totals) { + per_stream_bytes += sub_batch_cols * sizeof(double); + if (compute_nnz) { + per_stream_bytes += sub_batch_cols * sizeof(double); + } + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + // Stage data when indices are resident and one transpose stream still fits. + // Otherwise values stay mapped zero-copy for bounded-memory streaming. + size_t resident = indices_staged ? idx_bytes : 0; + bool data_staged = total_nnz > 0 && indices_staged && + resident + data_bytes + per_stream_bytes <= budget; + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + size_t stream_budget = budget - resident - (data_staged ? data_bytes : 0); + n_streams = + clamp_streams_by_budget(n_streams, per_stream_bytes, stream_budget); + + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + InT* d_data = nullptr; + if (total_nnz > 0) { + if (data_staged) { + d_data = pool.alloc(total_nnz); + cuda_check( + cudaMemcpy(d_data, h_data, data_bytes, cudaMemcpyHostToDevice), + "OVR host CSR stage data H2D"); + } else { + pin_data = HostRegisterGuard(const_cast(h_data), data_bytes, + cudaHostRegisterMapped); + cuda_check(cudaHostGetDevicePointer((void**)&d_data, + const_cast(h_data), 0), + "OVR host CSR map data"); + } + } + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_nnz; + double* sub_total_sums; + double* sub_total_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_total_sums = + compute_totals ? pool.alloc(sub_batch_cols) : nullptr; + bufs[s].sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = + checked_int_span(batches.nnz[b], "OVR host CSR active batch nnz"); + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data, d_indices, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_nnz, buf.sub_total_sums, buf.sub_total_nnz, sb_cols, + n_groups, compute_nnz, compute_totals, tpb, smem_cast, + cast_use_gmem, stream); + + if (batch_nnz > 0) { + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, stream, + "host CSR OVR segmented sort"); + } + + launch_ovr_sparse_rank( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); + + scatter_cols_2d(d_rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + scatter_cols_2d(d_group_sums + col, buf.sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, buf.sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, buf.sub_total_sums, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, buf.sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + + col += sb_cols; + } + + sync_streams(streams, "sparse host CSR streaming"); +} + +// Sign-safe sparse OVR path: sparse window -> dense f32 tile -> dense rank. + +constexpr int SPARSE_DENSE_OVR_CHUNK_COLS = 512; + +static void launch_ovr_dense_rank_window( + const float* dense, const int* group_codes, double* rank_sums, + double* tie_corr, int out_col, int n_rows, int n_cols_total, + int window_cols, int n_groups, bool compute_tie_corr, + int rank_sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || window_cols == 0 || n_groups == 0) return; + if (rank_sub_batch_cols <= 0) rank_sub_batch_cols = SUB_BATCH_COLS; + + DenseColumnBatchPlan batches = plan_dense_column_batches( + n_rows, window_cols, rank_sub_batch_cols, SAFE_BATCH_NNZ, + "Sparse-dense OVR rank sub-batch"); + rank_sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(window_cols, rank_sub_batch_cols); + size_t sub_items = batches.max_items; + int sub_items_i32 = + checked_cub_items(sub_items, "Sparse-dense OVR rank sub-batch"); + size_t cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(sub_items_i32, rank_sub_batch_cols); + + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamNonBlocking); + ScopedCudaEvent inputs_ready(cudaEventDisableTiming); + inputs_ready.record(upstream_stream); + for (int s = 0; s < n_streams; s++) { + cuda_check(cudaStreamWaitEvent(streams[s], inputs_ready.get(), 0), + "wait on sparse-dense OVR tile"); + } + + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(rank_sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * rank_sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(rank_sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < window_cols) { + int sb_cols = std::min(rank_sub_batch_cols, window_cols - col); + int sb_items = + checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Sparse-dense OVR active rank sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + const float* keys_in = dense + (size_t)col * n_rows; + cub_segmented_sortpairs( + buf.cub_temp, cub_temp_bytes, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, stream, "sparse-dense OVR segmented sort"); + + if (use_gmem) { + cuda_check(cudaMemsetAsync( + buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream), + "sparse-dense OVR gmem rank_sums memset"); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + scatter_cols_2d(rank_sums + out_col + col, buf.sub_rank_sums, n_groups, + n_cols_total, sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + out_col + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "sparse-dense OVR rank"); +} + +template +static void ovr_dense_csr_streaming_impl( + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int chunk_cols, + int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + + DenseColumnBatchPlan chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Device CSR sparse-dense OVR column chunk"); + chunk_cols = chunks.sub_batch_cols; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + float* dense = pool.alloc(max_dense_items); + + for (int col = 0; col < n_cols; col += chunk_cols) { + int sb_cols = std::min(chunk_cols, n_cols - col); + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_tile_to_dense_kernel + <<>>(csr_indptr, csr_indices, + csr_data, dense, col, + col + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csr_tile_to_dense_kernel); + launch_ovr_dense_rank_window( + dense, group_codes, rank_sums, tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + } +} + +template +static void ovr_dense_csc_streaming_impl( + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int chunk_cols, + int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + + DenseColumnBatchPlan chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Device CSC sparse-dense OVR column chunk"); + chunk_cols = chunks.sub_batch_cols; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + float* dense = pool.alloc(max_dense_items); + + for (int col = 0; col < n_cols; col += chunk_cols) { + int sb_cols = std::min(chunk_cols, n_cols - col); + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>(csc_indptr, csc_indices, + csc_data, dense, col, + col + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + launch_ovr_dense_rank_window( + dense, group_codes, rank_sums, tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + } +} + +template +static void ovr_dense_csc_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, double* d_rank_sums, double* d_tie_corr, + double* d_group_sums, double* d_group_nnz, double* d_total_sums, + double* d_total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_stats, bool compute_nnz, + bool compute_totals, int chunk_cols, int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + compute_nnz = compute_stats && compute_nnz && d_group_nnz != nullptr; + compute_totals = compute_stats && compute_totals && d_total_sums != nullptr; + + DenseColumnBatchPlan dense_chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSC sparse-dense OVR column chunk"); + chunk_cols = dense_chunks.sub_batch_cols; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSC sparse-dense OVR offsets"); + chunk_cols = batches.sub_batch_cols; + size_t max_nnz = batches.max_nnz; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + PinnedRing stage(1, max_nnz ? max_nnz : 1); + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + + int* d_group_codes = pool.alloc(n_rows); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int* d_all_offsets = upload_batch_offsets(batches, pool); + + InT* d_sparse_data_orig = pool.alloc(max_nnz ? max_nnz : 1); + float* d_sparse_data_f32 = pool.alloc(max_nnz ? max_nnz : 1); + IndexT* d_sparse_indices = pool.alloc(max_nnz ? max_nnz : 1); + int* idx_i32 = (sizeof(IndexT) > sizeof(int)) + ? pool.alloc(max_nnz ? max_nnz : 1) + : nullptr; + int* d_indptr = pool.alloc(chunk_cols + 1); + float* dense = pool.alloc(max_dense_items); + double* sub_group_sums = + compute_stats ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(chunk_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(chunk_cols) + : nullptr; + + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + int col = 0; + for (int b = 0; b < batches.n_batches; b++) { + int sb_cols = std::min(chunk_cols, n_cols - col); + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "Host CSC sparse-dense active nnz"); + + stage.wait(0); + if (batch_nnz > 0) { + host_copy_slice(h_data, h_indices, (size_t)ptr_start, batch_nnz, + stage.template get<0>(0), stage.template get<1>(0)); + cudaMemcpyAsync(d_sparse_data_orig, stage.template get<0>(0), + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_sparse_indices, stage.template get<1>(0), + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + stage.record(0, stream); + + int* idx32; + if constexpr (sizeof(IndexT) > sizeof(int)) { + if (batch_nnz > 0) { + int cblk = (batch_nnz + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + cast_array_kernel + <<>>( + d_sparse_indices, idx_i32, (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + } + idx32 = idx_i32; + } else { + idx32 = reinterpret_cast(d_sparse_indices); + } + + int* src = d_all_offsets + (size_t)b * (chunk_cols + 1); + cudaMemcpyAsync(d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + launch_ovr_cast_and_accumulate_sparse( + d_sparse_data_orig, d_sparse_data_f32, idx32, d_indptr, + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, + UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>( + d_indptr, idx32, d_sparse_data_f32, dense, 0, sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + + if (compute_stats) { + scatter_cols_2d(d_group_sums + col, sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + } + + launch_ovr_dense_rank_window( + dense, d_group_codes, d_rank_sums, d_tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + col += sb_cols; + } +} + +template +static void ovr_dense_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, double* d_rank_sums, double* d_tie_corr, + double* d_group_sums, double* d_group_nnz, double* d_total_sums, + double* d_total_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_stats, bool compute_nnz, + bool compute_totals, int chunk_cols, int rank_sub_batch_cols) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; + compute_nnz = compute_stats && compute_nnz && d_group_nnz != nullptr; + compute_totals = compute_stats && compute_totals && d_total_sums != nullptr; + + DenseColumnBatchPlan dense_chunks = + plan_dense_column_batches(n_rows, n_cols, chunk_cols, SAFE_BATCH_NNZ, + "Host CSR sparse-dense OVR column chunk"); + chunk_cols = dense_chunks.sub_batch_cols; + + std::vector h_col_counts(n_cols, 0); + { + int n_workers = host_worker_count(); + std::vector> local(n_workers, + std::vector(n_cols, 0)); + int used = host_parallel_chunks(n_rows, [&](int w, int r0, int r1) { + std::vector& lc = local[w]; + for (IndptrT p = h_indptr[r0]; p < h_indptr[r1]; p++) + lc[(size_t)h_indices[p]]++; + }); + for (int w = 0; w < used; w++) + for (int c = 0; c < n_cols; c++) h_col_counts[c] += local[w][c]; + } + + size_t cap = SAFE_BATCH_NNZ; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, chunk_cols, cap, [&](int c) { return h_col_counts[c]; }, + "Host CSR sparse-dense OVR offsets"); + chunk_cols = batches.sub_batch_cols; + size_t max_batch_nnz = batches.max_nnz; + size_t max_dense_items = (size_t)n_rows * (size_t)chunk_cols; + + RmmScratchPool pool; + ScopedCudaStream extract_stream(cudaStreamDefault); + cudaStream_t stream = extract_stream.get(); + PinnedRing gather_stage(1, max_batch_nnz ? max_batch_nnz : 1); + PinnedRing indptr_stage(1, (size_t)n_rows + 1); + std::vector cursor(n_rows, 0); + std::vector row_counts(n_rows); + + int* d_group_codes = pool.alloc(n_rows); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int* d_all_offsets = upload_batch_offsets(batches, pool); + + InT* d_gather_vals = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* d_gather_cols = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* d_gather_indptr = pool.alloc(n_rows + 1); + int* col_offsets = pool.alloc(chunk_cols + 1); + int* write_pos = pool.alloc(chunk_cols); + InT* csc_vals_orig = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + float* csc_vals_f32 = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + int* csc_row_idx = pool.alloc(max_batch_nnz ? max_batch_nnz : 1); + float* dense = pool.alloc(max_dense_items); + double* sub_group_sums = + compute_stats ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * chunk_cols) + : nullptr; + double* sub_total_sums = + compute_totals ? pool.alloc(chunk_cols) : nullptr; + double* sub_total_nnz = (compute_totals && compute_nnz) + ? pool.alloc(chunk_cols) + : nullptr; + + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups, compute_nnz, compute_totals, cast_use_gmem); + + int col = 0; + for (int b = 0; b < batches.n_batches; b++) { + int sb_cols = std::min(chunk_cols, n_cols - col); + int col_end = col + sb_cols; + gather_stage.wait(0); + indptr_stage.wait(0); + InT* h_gather_vals = gather_stage.template get<0>(0); + int* h_gather_cols = gather_stage.template get<1>(0); + int* h_gather_indptr = indptr_stage.template get<0>(0); + + int batch_nnz = host_materialize_csr_column_interval_cursor( + h_data, h_indices, h_indptr, n_rows, col, col_end, cursor.data(), + row_counts.data(), h_gather_indptr, h_gather_vals, h_gather_cols, + "Host CSR sparse-dense gather nnz"); + + int* src = d_all_offsets + (size_t)b * (chunk_cols + 1); + cudaMemcpyAsync(col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + cudaMemcpyAsync(d_gather_vals, h_gather_vals, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_gather_cols, h_gather_cols, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + cudaMemcpyAsync(d_gather_indptr, h_gather_indptr, + (n_rows + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + gather_stage.record(0, stream); + indptr_stage.record(0, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<<(n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE, 0, stream>>>( + d_gather_vals, d_gather_cols, d_gather_indptr, write_pos, + csc_vals_orig, csc_row_idx, n_rows, col, col_end, 0); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + csc_vals_orig, csc_vals_f32, csc_row_idx, col_offsets, + d_group_codes, sub_group_sums, sub_group_nnz, sub_total_sums, + sub_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals, + UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + size_t dense_items = (size_t)n_rows * (size_t)sb_cols; + cudaMemsetAsync(dense, 0, dense_items * sizeof(float), stream); + csc_tile_to_dense_kernel + <<>>(col_offsets, csc_row_idx, + csc_vals_f32, dense, 0, + sb_cols, n_rows); + CUDA_CHECK_LAST_ERROR(csc_tile_to_dense_kernel); + + if (compute_stats) { + scatter_cols_2d(d_group_sums + col, sub_group_sums, n_groups, + n_cols, sb_cols, stream); + if (compute_nnz) { + scatter_cols_2d(d_group_nnz + col, sub_group_nnz, n_groups, + n_cols, sb_cols, stream); + } + if (compute_totals) { + cudaMemcpyAsync(d_total_sums + col, sub_total_sums, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + if (compute_nnz) { + cudaMemcpyAsync(d_total_nnz + col, sub_total_nnz, + sb_cols * sizeof(double), + cudaMemcpyDeviceToDevice, stream); + } + } + } + + launch_ovr_dense_rank_window( + dense, d_group_codes, d_rank_sums, d_tie_corr, col, n_rows, n_cols, + sb_cols, n_groups, compute_tie_corr, rank_sub_batch_cols, stream); + col += sb_cols; + } +} + +// Sparse-aware CSC OVR streaming: sort only stored nonzeros. + +template +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const IndexT* csc_indices, const IndptrT* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning. + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(IndptrT), + cudaMemcpyDeviceToHost); + + // Bound each batch's nnz: CUB item counts within int32 + sort buffers fit. + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + ColumnBatchPlan batches = + plan_csc_column_batches(h_indptr.data(), n_cols, sub_batch_cols, cap, + "OVR device CSC rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_streams = clamp_streams_by_cols(n_cols, sub_batch_cols); + size_t max_nnz = batches.max_nnz; + + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR device CSC sparse sub-batch nnz"); + cub_temp_bytes = + cub_segmented_sortpairs_temp_bytes(max_nnz_i32, sub_batch_cols); + } + + // pool first: streams drain before it frees their scratch (see guard doc). + RmmScratchPool pool; + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + + struct StreamBuf { + float* keys_out; + int* vals_out; + int* idx_i32; // int32 sort-val scratch; only used when IndexT != int + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].idx_i32 = + (sizeof(IndexT) > sizeof(int)) ? pool.alloc(max_nnz) : nullptr; + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR device CSC active batch nnz"); + + // Rebase segment offsets on GPU (avoids host pinned-buffer race). + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort stored values; row indices become int32 sort values here. + // This keeps sort/rank int32 while preserving int64 sparse buffers. + if (batch_nnz > 0) { + const int* idx_src; + if constexpr (sizeof(IndexT) > sizeof(int)) { + int cblk = (batch_nnz + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + cast_array_kernel + <<>>( + csc_indices + ptr_start, buf.idx_i32, + (size_t)batch_nnz); + CUDA_CHECK_LAST_ERROR(cast_array_kernel); + idx_src = buf.idx_i32; + } else { + idx_src = csc_indices + ptr_start; + } + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, + csc_data + ptr_start, buf.keys_out, idx_src, + buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, + stream, "device CSC OVR segmented sort"); + } + + // Sparse rank kernel (handles implicit zeros analytically) + launch_ovr_sparse_rank(buf.keys_out, buf.vals_out, buf.seg_offsets, + group_codes, group_sizes, buf.sub_rank_sums, + buf.sub_tie_corr, buf.d_nz_scratch, n_rows, + sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); + + scatter_cols_2d(rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + sync_streams(streams, "sparse ovr streaming"); +} + +// Sparse-aware CSR OVR streaming with partial CSR->CSC transpose per batch. +// Histogram plans batches; each batch transposes, sorts nnz only, then ranks. +template +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const IndexT* csr_indices, const IndptrT* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: count nnz per column via histogram ---- + RmmScratchPool pool; + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), + cudaMemcpyDeviceToHost); + + // Bound each batch's nnz: CUB item counts within int32 + transpose/sort + // buffers fit. + constexpr size_t BYTES_PER_NNZ = 2 * sizeof(float) + 2 * sizeof(int) + 8; + size_t cap = SAFE_BATCH_NNZ; + size_t mem_cap = + rmm_available_device_bytes(0.8) / (size_t)N_STREAMS / BYTES_PER_NNZ; + if (mem_cap > 0 && mem_cap < cap) cap = mem_cap; + ColumnBatchPlan batches = plan_column_batches_from_counts( + n_cols, sub_batch_cols, cap, + [&](int c) { return (size_t)h_col_counts[c]; }, + "OVR device CSR rebased column offsets"); + sub_batch_cols = batches.sub_batch_cols; + int n_batches = batches.n_batches; + size_t max_batch_nnz = batches.max_nnz; + + // Upload all batch offsets in one H2D. + int* d_all_offsets = upload_batch_offsets(batches, pool); + + // ---- Phase 1: per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR device CSR sparse sub-batch nnz"); + cub_temp_bytes = cub_segmented_sortpairs_temp_bytes(max_batch_nnz_i32, + sub_batch_cols); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + CUB + // output); fit stream count to available GPU memory. + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (rank_use_gmem) { + // gmem fallback (n_groups too large for smem): per-stream d_nz_scratch, + // same size as sub_rank_sums. + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t budget = rmm_available_device_bytes(0.8); + n_streams = clamp_streams_by_budget(n_streams, per_stream_bytes, budget); + + ScopedCudaStreams streams(n_streams, cudaStreamDefault); + + int tpb = UTIL_BLOCK_SIZE; + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // CSC-style offsets + int* write_pos; // atomic write counters + float* csc_vals; // transposed values + int* csc_row_idx; // transposed row indices + float* keys_out; // CUB sort output + int* vals_out; // CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = + checked_int_span(batches.nnz[b], "OVR device CSR active batch nnz"); + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // write_pos = col_offsets[0..sb_cols-1] (same D2D source). + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR -> CSC for this sub-batch. + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // Sort only the nonzeros. + cub_segmented_sortpairs(buf.cub_temp, cub_temp_bytes, buf.csc_vals, + buf.keys_out, buf.csc_row_idx, buf.vals_out, + batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, stream, + "device CSR OVR segmented sort"); + } + + // Sparse rank kernel (handles implicit zeros analytically) + launch_ovr_sparse_rank(buf.keys_out, buf.vals_out, buf.col_offsets, + group_codes, group_sizes, buf.sub_rank_sums, + buf.sub_tie_corr, buf.d_nz_scratch, n_rows, + sb_cols, n_groups, tpb, smem_bytes, + compute_tie_corr, rank_use_gmem, stream); + + scatter_cols_2d(rank_sums + col, buf.sub_rank_sums, n_groups, n_cols, + sb_cols, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + sync_streams(streams, "sparse CSR ovr streaming"); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh new file mode 100644 index 00000000..d02103cc --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_tie_walk.cuh @@ -0,0 +1,69 @@ +#pragma once + +#include + +// Walk one sorted-column chunk and accumulate tie-averaged ranks atomically. +// Boundary ties are expanded by search; sparse paths pass a rank_offset. +template +__device__ __forceinline__ double ovr_walk_tie_runs( + const float* sv, const IndexT* si, const int* group_codes, double* grp_sums, + int acc_stride, int n_groups, int my_start, int my_end, int seg_floor, + int seg_ceil, double rank_offset, bool compute_tie_corr) { + double local_tie_sum = 0.0; + int i = my_start; + while (i < my_end) { + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > seg_floor && sv[i - 1] == val) { + // tie spans into a prior chunk: find global tie start. + int lo = seg_floor, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < seg_ceil && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = seg_ceil - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + double avg_rank = + rank_offset + + ((double)tie_global_start + (double)tie_global_end + 1.0) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp >= 0 && grp < n_groups) { + atomicAdd(&grp_sums[(size_t)grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + return local_tie_sum; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu new file mode 100644 index 00000000..e3a1032c --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -0,0 +1,342 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon.cuh" +#include "wilcoxon_sparse_kernels.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" + +using namespace nb::literals; + +template +void register_sparse_bindings(nb::module_& m) { + m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; + +#define RSC_OVR_SPARSE_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c group_codes, \ + gpu_array_c group_sizes, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(data.data(), indices.data(), indptr.data(), \ + group_codes.data(), group_sizes.data(), rank_sums.data(), \ + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, \ + sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "group_codes"_a, "group_sizes"_a, \ + "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, \ + "n_groups"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", + ovr_sparse_csc_streaming_impl, int, int); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csc_device", + ovr_sparse_csc_streaming_impl, int64_t, + int64_t); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", + ovr_sparse_csr_streaming_impl, int, int); + RSC_OVR_SPARSE_DEVICE_BINDING("ovr_sparse_csr_device", + ovr_sparse_csr_streaming_impl, int64_t, + int64_t); +#undef RSC_OVR_SPARSE_DEVICE_BINDING + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_nnz, \ + bool compute_totals, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_nnz.data(), d_total_sums.data(), d_total_nnz.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, \ + compute_totals, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_nnz"_a, "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", double, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int64_t, + int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", double, int64_t, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_nnz, \ + bool compute_totals, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_nnz.data(), d_total_sums.data(), d_total_nnz.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, compute_nnz, \ + compute_totals, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_nnz"_a, "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", double, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int64_t, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", double, int64_t, + int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + +#define RSC_OVR_DENSE_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c group_codes, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int chunk_cols, \ + int rank_sub_batch_cols) { \ + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; \ + if (rank_sub_batch_cols <= 0) \ + rank_sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(data.data(), indices.data(), indptr.data(), \ + group_codes.data(), rank_sums.data(), tie_corr.data(), \ + n_rows, n_cols, n_groups, compute_tie_corr, chunk_cols, \ + rank_sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "group_codes"_a, "rank_sums"_a, \ + "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "chunk_cols"_a = SPARSE_DENSE_OVR_CHUNK_COLS, \ + "rank_sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_DEVICE_BINDING("ovr_dense_csc_device", + ovr_dense_csc_streaming_impl, int, int); + RSC_OVR_DENSE_DEVICE_BINDING( + "ovr_dense_csc_device", ovr_dense_csc_streaming_impl, int64_t, int64_t); + RSC_OVR_DENSE_DEVICE_BINDING("ovr_dense_csr_device", + ovr_dense_csr_streaming_impl, int, int); + RSC_OVR_DENSE_DEVICE_BINDING( + "ovr_dense_csr_device", ovr_dense_csr_streaming_impl, int64_t, int64_t); +#undef RSC_OVR_DENSE_DEVICE_BINDING + +#define RSC_OVR_DENSE_HOST_BINDING(NAME, IMPL, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, \ + gpu_array_c d_total_sums, \ + gpu_array_c d_total_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_stats, \ + bool compute_nnz, bool compute_totals, int chunk_cols, \ + int rank_sub_batch_cols) { \ + if (chunk_cols <= 0) chunk_cols = SPARSE_DENSE_OVR_CHUNK_COLS; \ + if (rank_sub_batch_cols <= 0) \ + rank_sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), d_rank_sums.data(), d_tie_corr.data(), \ + d_group_sums.data(), d_group_nnz.data(), d_total_sums.data(), \ + d_total_nnz.data(), n_rows, n_cols, n_groups, \ + compute_tie_corr, compute_stats, compute_nnz, compute_totals, \ + chunk_cols, rank_sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, "d_group_nnz"_a, \ + "d_total_sums"_a, "d_total_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "compute_stats"_a, \ + "compute_nnz"_a = true, "compute_totals"_a = false, \ + "chunk_cols"_a = SPARSE_DENSE_OVR_CHUNK_COLS, \ + "rank_sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, float, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, double, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, float, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csc_host", + ovr_dense_csc_host_streaming_impl, double, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, float, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, double, int, + int); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, float, + int64_t, int64_t); + RSC_OVR_DENSE_HOST_BINDING("ovr_dense_csr_host", + ovr_dense_csr_host_streaming_impl, double, + int64_t, int64_t); +#undef RSC_OVR_DENSE_HOST_BINDING + +#define RSC_OVO_DEVICE_BINDING(NAME, IMPL, IndexCType, IndptrCType) \ + m.def( \ + NAME, \ + [](gpu_array_c data, \ + gpu_array_c indices, \ + gpu_array_c indptr, \ + gpu_array_c ref_rows, \ + gpu_array_c grp_rows, \ + gpu_array_c grp_offsets, \ + gpu_array_c rank_sums, \ + gpu_array_c tie_corr, int n_ref, int n_all_grp, \ + int n_cols, int n_groups, bool compute_tie_corr, \ + int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + IMPL(data.data(), indices.data(), indptr.data(), ref_rows.data(), \ + grp_rows.data(), grp_offsets.data(), rank_sums.data(), \ + tie_corr.data(), n_ref, n_all_grp, n_cols, n_groups, \ + compute_tie_corr, sub_batch_cols); \ + }, \ + "data"_a, "indices"_a, "indptr"_a, "ref_rows"_a, "grp_rows"_a, \ + "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), \ + "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, \ + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, + int, int); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csc_device", ovo_streaming_csc_impl, + int64_t, int64_t); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, + int, int); + RSC_OVO_DEVICE_BINDING("ovo_streaming_csr_device", ovo_streaming_csr_impl, + int64_t, int64_t); +#undef RSC_OVO_DEVICE_BINDING + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_nnz, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_nnz.data(), n_ref, n_all_grp, n_rows, n_cols, \ + n_groups, n_groups_stats, compute_tie_corr, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, "d_group_nnz"_a, \ + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, \ + "n_groups"_a, "n_groups_stats"_a, "compute_tie_corr"_a, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int64_t, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", double, int64_t, + int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_nnz, \ + bool compute_sums, int sub_batch_cols) { \ + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), d_group_nnz.data(), \ + n_cols, n_groups_stats, compute_tie_corr, compute_nnz, \ + compute_sums, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_full_rows"_a, \ + "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, "n_groups_stats"_a, \ + "compute_tie_corr"_a, "compute_nnz"_a = true, "compute_sums"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int64_t, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", double, int64_t, + int64_t); +#undef RSC_OVO_CSR_HOST_BINDING +} + +NB_MODULE(_wilcoxon_sparse_cuda, m) { + REGISTER_GPU_BINDINGS(register_sparse_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh new file mode 100644 index 00000000..a5078997 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -0,0 +1,317 @@ +#pragma once + +#include + +#include "wilcoxon_block_reduce.cuh" +#include "wilcoxon_ovr_tie_walk.cuh" + +// Sparse OVR rank for nonnegative stored values; zeros rank analytically. +// CRITICAL: negative rejection and gmem fallback are required at large +// n_groups. +template +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const IndexT* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; + + const float* sv = sorted_vals + seg_start; + const IndexT* si = sorted_row_idx + seg_start; + + extern __shared__ double smem[]; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; + + if (use_gmem) { + // rank_sums doubles as accumulator (pre-zeroed by caller). + grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + } + + // pos_start = first index where sv[i] > 0 (stored zeros precede positives). + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] <= 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int pos_start = sh_pos_start; + int n_stored_zero = pos_start; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = (total_zero > 0) ? (total_zero + 1.0) / 2.0 : 0.0; + + // Positive rank offset: full_pos(i)=i+n_implicit_zero; tie group [a,b) + // avg_rank = n_implicit_zero + (a+b+1)/2. + int offset_pos = n_implicit_zero; + + // Count stored positives per group. + for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { + int grp = group_codes[si[i]]; + if (grp >= 0 && grp < n_groups) { + atomicAdd(&grp_nz_count[(size_t)grp * acc_stride], 1.0); + } + } + __syncthreads(); + + // Analytic zero contribution: each group's zeros all get zero_avg_rank. + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = + group_sizes[g] - grp_nz_count[(size_t)g * acc_stride]; + grp_sums[(size_t)g * acc_stride] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // Walk stored positives and compute tie-averaged ranks. + int n_pos = nnz_stored - pos_start; + int chunk = (n_pos + blockDim.x - 1) / blockDim.x; + int my_start = pos_start + threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = ovr_walk_tie_runs( + sv, si, group_codes, grp_sums, acc_stride, n_groups, my_start, my_end, + /*seg_floor=*/pos_start, /*seg_ceil=*/nnz_stored, + /*rank_offset=*/(double)offset_pos, compute_tie_corr); + + __syncthreads(); + + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Single zero tie block contributes once. + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + local_tie_sum += tz * tz * tz - tz; + } + + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 0 : 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + + double v = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) tie_corr[col] = finalize_tie_corr(n_rows, v); + } +} + +// Shared sparse-OVR rank launch for all sparse OVR implementations. +// CRITICAL: keep the gmem fallback for large-n_groups perturbation DE. +template +static inline void launch_ovr_sparse_rank( + const float* sorted_vals, const ValT* sorted_row_idx, + const int* col_seg_offsets, const int* group_codes, + const double* group_sizes, double* rank_sums, double* tie_corr, + double* nz_count_scratch, int n_rows, int sb_cols, int n_groups, int tpb, + size_t smem_bytes, bool compute_tie_corr, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + cudaMemsetAsync(rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream); + cudaMemsetAsync(nz_count_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), stream); + } + rank_sums_sparse_ovr_kernel<<>>( + sorted_vals, sorted_row_idx, col_seg_offsets, group_codes, group_sizes, + rank_sums, tie_corr, nz_count_scratch, n_rows, sb_cols, n_groups, + compute_tie_corr, use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); +} + +// CRITICAL: sparse stats gmem fallback is load-bearing for large n_groups. +// It selects the global accumulator when smem would exceed the per-block limit. +static size_t cast_accumulate_smem_config(int n_groups, bool compute_nnz, + bool compute_totals, bool& use_gmem) { + int n_arrays = 1 + (compute_nnz ? 1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (compute_totals) need += WARP_REDUCE_BUF * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return compute_totals ? WARP_REDUCE_BUF * sizeof(double) : 0; +} + +// Shared cast+accumulate loop for sparse-OVR stats kernels. +// Casts to f32 for sort and atomically accumulates f64 sums/nnz. +template +__device__ __forceinline__ void accumulate_group_stats( + const InT* data_in, float* data_f32_out, const IndexT* indices, + int seg_start, int seg_end, const int* group_codes, double* sums, + double* nnz, int acc_stride, int n_groups, bool compute_nnz, + bool compute_totals, double& local_total_sum, double& local_total_nnz) { + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + if (compute_totals) { + local_total_sum += v; + if (compute_nnz && v != 0.0) local_total_nnz += 1.0; + } + int row = (int)indices[i]; + int g = group_codes[row]; + if (g >= 0 && g < n_groups) { + atomicAdd(&sums[(size_t)g * acc_stride], v); + if (compute_nnz && v != 0.0) + atomicAdd(&nnz[(size_t)g * acc_stride], 1.0); + } + } +} + +/** Pre-sort cast-and-accumulate kernel for sparse OVR streaming. + * Writes f32 sort keys and accumulates explicit-value sums/nnz in f64. */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_nnz, double* __restrict__ total_sums, + double* __restrict__ total_nnz, int sb_cols, int n_groups, + bool compute_nnz = true, bool compute_totals = false) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + // Packed layout matching cast_accumulate_smem_config ((1+compute_nnz)* + // n_groups doubles). + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_nnz = smem + n_groups; + double* warp_buf = smem + (size_t)(1 + (compute_nnz ? 1 : 0)) * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + double local_total_sum = 0.0; + double local_total_nnz = 0.0; + accumulate_group_stats( + data_in, data_f32_out, indices, seg_start, seg_end, group_codes, s_sum, + s_nnz, /*acc_stride=*/1, n_groups, compute_nnz, compute_totals, + local_total_sum, local_total_nnz); + __syncthreads(); + + if (compute_totals) { + double total = wilcoxon_block_sum(local_total_sum, warp_buf); + if (threadIdx.x == 0) total_sums[col] = total; + __syncthreads(); + if (compute_nnz) { + double nnz_total = wilcoxon_block_sum(local_total_nnz, warp_buf); + if (threadIdx.x == 0) total_nnz[col] = nnz_total; + __syncthreads(); + } + } + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +// CRITICAL: gmem stats accumulator for n_groups too large for smem. +// Required for Perturb-seq-scale group counts. +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_nnz, double* __restrict__ total_sums, + double* __restrict__ total_nnz, int sb_cols, int n_groups, + bool compute_nnz = true, bool compute_totals = false) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double warp_buf[]; + double local_total_sum = 0.0; + double local_total_nnz = 0.0; + accumulate_group_stats( + data_in, data_f32_out, indices, seg_start, seg_end, group_codes, + group_sums + col, group_nnz + col, + /*acc_stride=*/sb_cols, n_groups, compute_nnz, compute_totals, + local_total_sum, local_total_nnz); + if (compute_totals) { + double total = wilcoxon_block_sum(local_total_sum, warp_buf); + if (threadIdx.x == 0) total_sums[col] = total; + __syncthreads(); + if (compute_nnz) { + double nnz_total = wilcoxon_block_sum(local_total_nnz, warp_buf); + if (threadIdx.x == 0) total_nnz[col] = nnz_total; + } + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_nnz, double* d_total_sums, double* d_total_nnz, int sb_cols, + int n_groups, bool compute_nnz, bool compute_totals, int tpb, + size_t smem_cast, bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_nnz, d_total_sums, + d_total_nnz, sb_cols, n_groups, compute_nnz, compute_totals); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + } +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..0ef1f7eb 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,6 +21,33 @@ ] +def _array_result_to_records( + arrays: dict[str, object], field: str, dtype: str | np.dtype +) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + values = np.asarray(arrays[field]) + out = np.empty( + values.shape[1], + dtype=[(group_name, np.dtype(dtype)) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = values[row] + return out + + +def _array_result_to_names(arrays: dict[str, object]) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + var_names = np.asarray(arrays["var_names"]) + gene_indices = np.asarray(arrays["gene_indices"], dtype=np.intp) + out = np.empty( + gene_indices.shape[1], + dtype=[(group_name, object) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = var_names[gene_indices[row]] + return out + + def rank_genes_groups( adata: AnnData, groupby: str, @@ -37,22 +64,44 @@ def rank_genes_groups( corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, use_continuity: bool = False, + return_u_values: bool = False, layer: str | None = None, chunk_size: int | None = None, - pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ Rank genes for characterizing groups using GPU acceleration. - Expects logarithmized data. + Log1p/log-normalized data is expected for biologically meaningful log fold + changes. In-memory sparse ``wilcoxon`` inputs with explicit negative values + use sign-safe dense ranking in the CUDA sparse streamers, materializing + bounded dense tiles inside the nanobind path. Dense inputs are ranked + directly and support any sign. + (``wilcoxon_binned`` rejects negative Dask sparse input, which it cannot + bin correctly.) .. note:: - **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and - `'wilcoxon_binned'` support Dask arrays. The `'wilcoxon'` and - `'logreg'` methods do not support Dask arrays. + **Dask support:** `'t-test'`, `'t-test_overestim_var'`, + `'wilcoxon_binned'`, and `'logreg'` support Dask arrays. The + `'wilcoxon'` method does not support Dask arrays. + + .. note:: + **Wilcoxon ranking precision:** `'wilcoxon'` and `'wilcoxon_binned'` + rank values in float32 on every code path, while means and log fold + changes are computed in float64. This only diverges from Scanpy when the + **preprocessing itself ran in float64** — i.e. normalization/log1p + produced values carrying sub-float32 precision. If preprocessing was + done in float32 (the common case), the values are float32-exact and + ranking is bit-identical to Scanpy (~1e-13), even if they are afterward + stored as float64. For a fully float64 pipeline the rank-derived scores + and p-values still match Scanpy-on-float64 to ~1e-4 on log-normalized + data — below any significance threshold and changing no DE calls — + because the rank-sum normal approximation is insensitive to sub-float32 + tie jitter. If exact float64 ranking matters for your workflow, please + open an issue at https://github.com/scverse/rapids_singlecell/issues. Parameters ---------- @@ -101,15 +150,17 @@ def rank_genes_groups( z-scores. Subtracts 0.5 from ``|R - E[R]|`` before dividing by the standard deviation, matching :func:`scipy.stats.mannwhitneyu` default behavior. + return_u_values + For `'wilcoxon'`, store Mann-Whitney U statistics in `scores` instead + of z-scores. P-values are still computed from the z-score normal + approximation using the selected tie and continuity settings. layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size Number of genes to process at once for `'wilcoxon'` and - `'wilcoxon_binned'`. Default is 128 for `'wilcoxon'`. For + `'wilcoxon_binned'`. Default is 512 for `'wilcoxon'`. For `'wilcoxon_binned'` the default is sized dynamically based on ``n_groups`` and ``n_bins`` to keep histogram memory stable. - pre_load - Pre-load the data into GPU memory. Used only for `'wilcoxon'`. n_bins Number of histogram bins for `'wilcoxon_binned'`. Higher values give a better approximation at slightly increased cost. Default is 1000 @@ -119,15 +170,20 @@ def rank_genes_groups( ``None`` (default) uses ``'auto'`` for in-memory arrays and ``'log1p'`` for Dask arrays (to avoid a costly data scan). ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. - ``'auto'`` computes the actual data range. Use this for z-scored - or unnormalized data. + ``'auto'`` computes the actual data range. Use this for nonnegative + expression data outside the fixed log1p range. + skip_empty_groups + Skip selected groups with fewer than two observations after filtering. + This is useful for perturbation workflows where a per-cell-type slice + keeps categories that are empty or singleton in that slice. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. Returns ------- - Updates `adata` with the following fields: + Updates `adata` with the following fields. Rank result fields are + Scanpy-compatible structured arrays. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -135,7 +191,8 @@ def rank_genes_groups( `adata.uns['rank_genes_groups' | key_added]['scores']` Structured array to be indexed by group id storing the z-score underlying the computation of a p-value for each gene for each - group. Ordered according to scores. + group, or the Mann-Whitney U statistic when + `return_u_values=True`. Ordered according to scores. `adata.uns['rank_genes_groups' | key_added]['logfoldchanges']` Structured array to be indexed by group id storing the log2 fold change for each gene for each group. @@ -154,6 +211,13 @@ def rank_genes_groups( msg = "corr_method must be either 'benjamini-hochberg' or 'bonferroni'." raise ValueError(msg) + if "return_format" in kwds: + msg = ( + "return_format has been removed; rank_genes_groups always writes " + "Scanpy-compatible structured results to adata.uns." + ) + raise TypeError(msg) + if method is None: method = "t-test" @@ -170,10 +234,17 @@ def rank_genes_groups( ) raise ValueError(msg) + if return_u_values and method != "wilcoxon": + msg = "return_u_values is only supported for method='wilcoxon'." + raise ValueError(msg) + + if chunk_size is not None and chunk_size <= 0: + msg = "chunk_size must be a positive integer." + raise ValueError(msg) + if key_added is None: key_added = "rank_genes_groups" - # Process mask_var: convert string to boolean array mask_var_array: NDArray[np.bool_] | None = None if mask_var is not None: if isinstance(mask_var, str): @@ -196,10 +267,9 @@ def rank_genes_groups( use_raw=use_raw, layer=layer, comp_pts=pts, - pre_load=pre_load, + skip_empty_groups=skip_empty_groups, ) - # Determine n_genes_user n_genes_user = n_genes if n_genes_user is None or n_genes_user > test_obj.X.shape[1]: n_genes_user = test_obj.X.shape[1] @@ -211,25 +281,14 @@ def rank_genes_groups( rankby_abs=rankby_abs, tie_correct=tie_correct, use_continuity=use_continuity, + return_u_values=return_u_values, chunk_size=chunk_size, n_bins=n_bins, bin_range=bin_range, **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - - dtypes = { - "names": "U50", - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - } - - adata.uns[key_added] = {} - adata.uns[key_added]["params"] = { + params = { "groupby": groupby, "reference": reference, "method": method, @@ -237,10 +296,22 @@ def rank_genes_groups( "layer": layer, "corr_method": corr_method, } + if method == "wilcoxon": + params["tie_correct"] = tie_correct + params["return_u_values"] = return_u_values + + arrays = test_obj.stats_arrays or {} + adata.uns[key_added] = {"params": params} + if arrays and len(arrays.get("group_names", ())) > 0: + adata.uns[key_added]["names"] = _array_result_to_names(arrays) + for col in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + if col in arrays: + values = arrays[col] + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_records(arrays, col, dtype) - # Store pts results if computed + groups_names = [str(name) for name in test_obj.groups_order] if test_obj.pts is not None: - groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( test_obj.pts.T, index=test_obj.var_names, columns=groups_names ) @@ -249,14 +320,7 @@ def rank_genes_groups( test_obj.pts_rest.T, index=test_obj.var_names, columns=groups_names ) - if method == "wilcoxon": - adata.uns[key_added]["params"]["tie_correct"] = tie_correct - - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] - ) + return None if TYPE_CHECKING: @@ -285,7 +349,7 @@ def rank_genes_groups_logreg( layer: str | None = None, **kwds, ) -> None: - rank_genes_groups( + return rank_genes_groups( adata, groupby, groups=groups, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 2ccf87b5..6d39ab60 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,27 @@ from __future__ import annotations +import os +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import ( + EPS, + _canonicalize_sparse, + _select_groups, + _sparse_has_negative, +) + +_RANK_SORT_MIN_ELEMENTS = 1_000_000 +_RANK_SORT_MAX_WORKERS = 64 if TYPE_CHECKING: from collections.abc import Iterable @@ -37,9 +46,8 @@ def __init__( use_raw: bool | None = None, layer: str | None = None, comp_pts: bool = False, - pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: - # Handle groups parameter if groups == "all" or groups is None: selected: list | None = None elif isinstance(groups, str | int): @@ -63,10 +71,12 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + reference=reference, + skip_empty_groups=skip_empty_groups, ) - # Get data matrix if layer is not None: if use_raw is True: msg = "Cannot specify `layer` and have `use_raw=True`." @@ -86,26 +96,23 @@ def __init__( self.X = adata.X self.var_names = adata.var_names - # Apply mask_var to select subset of genes if mask_var is not None: self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] - self.pre_load = pre_load - self.ireference = None if reference != "rest": self.ireference = int(np.where(self.groups_order == str(reference))[0][0]) - # Set up expm1 function based on log base + # expm1 function depends on the log base used by log1p self.is_log1p = "log1p" in adata.uns base = adata.uns.get("log1p", {}).get("base") + self._log1p_base = base if base is not None: self.expm1_func = lambda x: np.expm1(x * np.log(base)) else: self.expm1_func = np.expm1 - # For basic stats self.comp_pts = comp_pts self.means: np.ndarray | None = None self.vars: np.ndarray | None = None @@ -114,9 +121,14 @@ def __init__( self.vars_rest: np.ndarray | None = None self.pts_rest: np.ndarray | None = None - self.stats: pd.DataFrame | None = None + self.stats_arrays: dict[str, object] | None = None + self._sparse_negative_fallback = False + self._store_wilcoxon_gpu_result = False + self._wilcoxon_gpu_result: ( + tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None + ) = None self._compute_stats_in_chunks: bool = False - self._ref_chunk_computed: set[int] = set() + self._score_dtype = np.dtype(np.float32) def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -143,13 +155,9 @@ def _init_stats_arrays(self, n_genes: int) -> None: def _basic_stats(self) -> None: """Compute means, vars, and pts for each group. - - If data is already on GPU, uses Aggregate for fast single-pass computation. - Otherwise, sets flag for chunk-based computation during the wilcoxon loop. - """ + Host data defers stats to the Wilcoxon chunk/streaming path.""" n_genes = self.X.shape[1] - # Check if data is already on GPU try: _check_gpu_X(self.X, allow_dask=True) except TypeError: @@ -158,12 +166,11 @@ def _basic_stats(self) -> None: is_on_gpu = True if not is_on_gpu: - # Data not on GPU - defer to chunk-based computation + # Not on GPU: defer to chunk-based computation in the wilcoxon loop self._compute_stats_in_chunks = True self._init_stats_arrays(n_genes) return - # Data is on GPU - use Aggregate for fast computation self._compute_stats_in_chunks = False agg = Aggregate(groupby=self.labels.cat, data=self.X) @@ -179,9 +186,8 @@ def _basic_stats(self) -> None: cat_to_idx = {str(name): i for i, name in enumerate(cat_names)} order = [cat_to_idx[str(name)] for name in self.groups_order] - # Aggregate returns stats per ALL categories. Slice to selected groups - # for per-group means/vars; keep the all-category arrays for "rest" - # stats so the totals stay correct when ``groups`` is a strict subset. + # Aggregate returns all categories; slice selected groups for outputs. + # Keep all-category totals so ``groups`` subsets get correct rest stats. sums_all = result["sum"] sq_sums_all = result["sq_sum"] nnz_all = result["count_nonzero"] if self.comp_pts else None @@ -190,7 +196,6 @@ def _basic_stats(self) -> None: sums = sums_all[order] sq_sums = sq_sums_all[order] - # Compute means and variances from raw sums (all on GPU) means = sums / n group_ss = sq_sums - n * means**2 vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0) @@ -200,9 +205,8 @@ def _basic_stats(self) -> None: else: pts = None - # Compute rest statistics if reference='rest' — "rest" means every - # cell in ``groupby`` not in this group, including cells in - # categories that weren't selected via ``groups=``. + # For reference='rest', rest includes every category not in this group. + # That includes categories omitted by a strict ``groups=`` selection. if self.ireference is None: n_total = agg.n_cells.sum() n_rest = n_total - n @@ -225,7 +229,6 @@ def _basic_stats(self) -> None: self.vars_rest = None self.pts_rest = None - # Transfer to CPU self.means = cp.asnumpy(means) self.vars = cp.asnumpy(vars_) self.pts = cp.asnumpy(pts) if pts is not None else None @@ -236,7 +239,7 @@ def _accumulate_chunk_stats_vs_rest( start: int, stop: int, *, - group_matrix: cp.ndarray, + group_codes_dev: cp.ndarray, group_sizes_dev: cp.ndarray, n_cells: int, ) -> None: @@ -246,25 +249,36 @@ def _accumulate_chunk_stats_vs_rest( rest_sizes = n_cells - group_sizes_dev - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) + n_groups = len(self.groups_order) + n_cols = stop - start + group_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_sum_sq = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_nnz = ( + cp.zeros((n_groups, n_cols), dtype=cp.float64) if self.comp_pts else None + ) + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + + _rs.group_chunk_stats( + block, + group_codes_dev, + group_sums, + group_sum_sq, + group_nnz if group_nnz is not None else group_sums, + compute_nnz=bool(self.comp_pts), + stream=cp.cuda.get_current_stream().ptr, + ) - # Means chunk_means = group_sums / group_sizes_dev[:, None] self.means[:, start:stop] = cp.asnumpy(chunk_means) - # Variances (with Bessel correction) + # variance with Bessel correction chunk_vars = group_sum_sq / group_sizes_dev[:, None] - chunk_means**2 chunk_vars *= group_sizes_dev[:, None] / (group_sizes_dev[:, None] - 1) self.vars[:, start:stop] = cp.asnumpy(chunk_vars) - # Pts (fraction expressing) if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) - # Rest statistics if self.ireference is None: total_sum = block.sum(axis=0) total_sum_sq = (block**2).sum(axis=0) @@ -285,49 +299,6 @@ def _accumulate_chunk_stats_vs_rest( rest_nnz / rest_sizes[:, None] ) - def _accumulate_chunk_stats_with_ref( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_index: int, - group_mask_gpu: cp.ndarray, - n_group: int, - n_ref: int, - ) -> None: - """Compute and store stats for one gene chunk (with reference mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - # Group stats - group_data = block[group_mask_gpu] - group_mean = group_data.mean(axis=0) - self.means[group_index, start:stop] = cp.asnumpy(group_mean) - - if n_group > 1: - group_var = group_data.var(axis=0, ddof=1) - self.vars[group_index, start:stop] = cp.asnumpy(group_var) - - if self.comp_pts: - group_nnz = (group_data != 0).sum(axis=0) - self.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) - - # Reference stats (only compute once, on first non-reference group) - if start not in self._ref_chunk_computed: - self._ref_chunk_computed.add(start) - ref_data = block[~group_mask_gpu] - ref_mean = ref_data.mean(axis=0) - self.means[self.ireference, start:stop] = cp.asnumpy(ref_mean) - - if n_ref > 1: - ref_var = ref_data.var(axis=0, ddof=1) - self.vars[self.ireference, start:stop] = cp.asnumpy(ref_var) - - if self.comp_pts: - ref_nnz = (ref_data != 0).sum(axis=0) - self.pts[self.ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) - def t_test( self, method: Literal["t-test", "t-test_overestim_var"] ) -> list[tuple[int, NDArray, NDArray]]: @@ -342,6 +313,7 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -351,6 +323,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) def wilcoxon_binned( @@ -392,27 +365,47 @@ def compute_statistics( chunk_size: int | None = None, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + return_u_values: bool = False, **kwds, ) -> None: """Compute statistics for all groups.""" - if self.pre_load or method in { + # Sparse Wilcoxon handles implicit zeros analytically only for nonnegative data. + # Signed sparse Wilcoxon routes to sign-safe dense ranking inside streamers. + self._sparse_negative_fallback = False + if method in {"wilcoxon", "wilcoxon_binned"}: + # Canonicalize before the negative check because summing duplicates can change signs. + # Fast paths rank stored nnz once, so they must see scanpy's summed view. + self.X = _canonicalize_sparse(self.X) + self._sparse_negative_fallback = _sparse_has_negative(self.X) + if method in { "t-test", "t-test_overestim_var", "wilcoxon_binned", }: self.X = X_to_GPU(self.X) + n_genes = self.X.shape[1] + if n_genes_user is None: + n_genes_user = n_genes + if method in {"t-test", "t-test_overestim_var"}: test_results = self.t_test(method) elif method == "wilcoxon": if isinstance(self.X, DaskArray): msg = "Wilcoxon test is not supported for Dask arrays. Please convert your data to CuPy arrays." raise ValueError(msg) - test_results = self.wilcoxon( - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - ) + self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) + self._wilcoxon_gpu_result = None + self._store_wilcoxon_gpu_result = True + try: + test_results = self.wilcoxon( + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + finally: + self._store_wilcoxon_gpu_result = False elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( tie_correct=tie_correct, @@ -426,58 +419,216 @@ def compute_statistics( else: assert_never(method) - n_genes = self.X.shape[1] + if not test_results and self._wilcoxon_gpu_result is None: + self.stats_arrays = { + "group_indices": np.empty(0, dtype=np.intp), + "group_names": np.empty(0, dtype=object), + "var_names": np.asarray(self.var_names), + "gene_indices": np.empty((0, n_genes_user), dtype=np.intp), + } + return + + if self._wilcoxon_gpu_result is not None: + group_indices, scores_gpu, pvals_gpu, logfoldchanges_gpu = ( + self._wilcoxon_gpu_result + ) + try: + self._compute_statistics_gpu_arrays( + group_indices, + scores_gpu, + pvals_gpu, + logfoldchanges_gpu, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) + finally: + self._wilcoxon_gpu_result = None + return - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} + self._compute_statistics_arrays( + test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + @staticmethod + def _rank_indices_matrix(scores: np.ndarray, n_top: int) -> np.ndarray: + if n_top >= scores.shape[1]: + return _RankGenes._argsort_desc_matrix(scores) + partition = np.argpartition(scores, -n_top, axis=1)[:, -n_top:] + row_ids = np.arange(scores.shape[0])[:, None] + order = np.argsort(scores[row_ids, partition], axis=1)[:, ::-1] + return partition[row_ids, order] + + @staticmethod + def _argsort_desc_matrix(scores: np.ndarray) -> np.ndarray: + n_rows, n_cols = scores.shape + n_elements = n_rows * n_cols + n_workers = min(_RANK_SORT_MAX_WORKERS, os.cpu_count() or 1, n_rows) + if n_workers <= 1 or n_elements < _RANK_SORT_MIN_ELEMENTS: + return np.argsort(scores, axis=1)[:, ::-1] + + chunks = np.linspace(0, n_rows, n_workers + 1, dtype=np.intp) + indices = np.empty((n_rows, n_cols), dtype=np.intp) + + def sort_chunk(chunk_index: int) -> None: + start = int(chunks[chunk_index]) + stop = int(chunks[chunk_index + 1]) + if start < stop: + indices[start:stop] = np.argsort(scores[start:stop], axis=1)[:, ::-1] + + with ThreadPoolExecutor(max_workers=n_workers) as executor: + list(executor.map(sort_chunk, range(n_workers))) + return indices + + @staticmethod + def _fdr_bh_matrix(pvals: np.ndarray) -> np.ndarray: + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 + order = np.argsort(pvals_clean, axis=1) + sorted_p = np.take_along_axis(pvals_clean, order, axis=1) + n_tests = sorted_p.shape[1] + scale = n_tests / np.arange(1, n_tests + 1, dtype=np.float64) + corrected_sorted = sorted_p * scale + corrected_sorted = np.minimum.accumulate(corrected_sorted[:, ::-1], axis=1)[ + :, ::-1 + ] + corrected_sorted[corrected_sorted > 1.0] = 1.0 + corrected = np.empty_like(corrected_sorted) + np.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected + + @staticmethod + def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: + pvals_clean = cp.nan_to_num(pvals, nan=1.0) + order = cp.argsort(pvals_clean, axis=1) + corrected_sorted = cp.take_along_axis(pvals_clean, order, axis=1) + corrected_sorted *= corrected_sorted.shape[1] / cp.arange( + 1, corrected_sorted.shape[1] + 1, dtype=cp.float64 + ) + from rapids_singlecell._cuda import _rank_stats_cuda as _rs - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) + _rs.fdr_bh_reverse_cummin( + corrected_sorted, stream=cp.cuda.get_current_stream().ptr + ) + corrected = cp.empty_like(corrected_sorted) + cp.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected + + def _logfoldchanges_into( + self, arrays: dict, group_indices: np.ndarray, top_idx: np.ndarray + ) -> None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) + + def _compute_statistics_arrays( + self, + test_results: list[tuple[int, NDArray, NDArray]], + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray([r[0] for r in test_results], dtype=np.intp) + scores = np.vstack([r[1] for r in test_results]) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": np.take_along_axis(scores, top_idx, axis=1).astype( + self._score_dtype, copy=False + ), + } + + if test_results[0][2] is not None: + pvals = np.vstack([r[2] for r in test_results]) + arrays["pvals"] = np.take_along_axis(pvals, top_idx, axis=1) + if corr_method == "benjamini-hochberg": + pvals_adj = self._fdr_bh_matrix(pvals) + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes, 1.0) else: - global_indices = slice(None) - - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] - - if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = np.take_along_axis(pvals_adj, top_idx, axis=1) + + if self.means is not None: + self._logfoldchanges_into(arrays, group_indices, top_idx) + + self.stats_arrays = arrays + + def _compute_statistics_gpu_arrays( + self, + group_indices: np.ndarray, + scores_gpu: cp.ndarray, + pvals_gpu: cp.ndarray, + logfoldchanges_gpu: cp.ndarray | None, + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray(group_indices, dtype=np.intp) + scores = cp.asnumpy(scores_gpu) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + top_idx_gpu = cp.asarray(top_idx) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": cp.asnumpy( + cp.take_along_axis(scores_gpu, top_idx_gpu, axis=1).astype( + self._score_dtype, copy=False ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] + ), + "pvals": cp.asnumpy(cp.take_along_axis(pvals_gpu, top_idx_gpu, axis=1)), + } + + if corr_method == "benjamini-hochberg": + pvals_adj_gpu = self._fdr_bh_matrix_gpu(pvals_gpu) + elif corr_method == "bonferroni": + pvals_adj_gpu = cp.minimum(pvals_gpu * n_genes, 1.0) + else: + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = cp.asnumpy( + cp.take_along_axis(pvals_adj_gpu, top_idx_gpu, axis=1) + ) + + if logfoldchanges_gpu is not None: + arrays["logfoldchanges"] = cp.asnumpy( + cp.take_along_axis(logfoldchanges_gpu, top_idx_gpu, axis=1).astype( + cp.float32, copy=False ) + ) + elif self.means is not None: + self._logfoldchanges_into(arrays, group_indices, top_idx) - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user is None: - self.stats.index = self.var_names - else: - self.stats = None + self.stats_arrays = arrays diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py index d4bf0dc3..d2decc70 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import cupy as cp +import numpy as np from rapids_singlecell._compat import DaskArray, _meta_dense @@ -21,7 +22,16 @@ def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]: n_groups = len(rg.groups_order) selected = rg.group_codes < n_groups X = rg.X[selected, :] - grouping_logreg = rg.group_codes[selected].astype(X.dtype) + codes = rg.group_codes[selected] + + # Encode multinomial classes in original category order for cuML softmax. + # groups_order follows the user; coef_ rows are mapped back below. + cat_order = {str(c): i for i, c in enumerate(rg.labels.cat.categories)} + canon_key = np.array([cat_order[str(g)] for g in rg.groups_order]) + canon_label = np.empty(n_groups, dtype=np.int64) + canon_label[np.argsort(canon_key, kind="stable")] = np.arange(n_groups) + relabel = cp.asarray(canon_label) if isinstance(codes, cp.ndarray) else canon_label + grouping_logreg = relabel[codes].astype(X.dtype) if isinstance(X, DaskArray): import dask.array as da @@ -46,7 +56,8 @@ def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]: if n_groups <= 2: scores = scores_all[0].get() else: - scores = scores_all[igroup].get() + # coef_ rows are in canonical class order; map back to groups_order. + scores = scores_all[int(canon_label[igroup])].get() results.append((igroup, scores, None)) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..19bdba49 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -7,49 +7,74 @@ import numpy as np import scipy.sparse as sp -from rapids_singlecell.preprocessing._utils import _sparse_to_dense - if TYPE_CHECKING: import pandas as pd from numpy.typing import NDArray EPS = 1e-9 -WARP_SIZE = 32 -MAX_THREADS_PER_BLOCK = 512 +MIN_GROUP_SIZE_WARNING = 25 + + +def _sparse_has_negative(X) -> bool: + """Return whether an in-memory sparse matrix stores a negative value. + Signed sparse Wilcoxon needs the sign-safe sparse-dense ranker.""" + if sp.issparse(X) or cpsp.issparse(X): + if np.dtype(X.data.dtype).kind == "c": + return False + return X.nnz > 0 and float(X.data.min()) < 0 + return False + + +def _canonicalize_sparse(X): + """Sum duplicates and sort sparse indices in place when needed. + Fast Wilcoxon ranks stored nnz once, so it expects scanpy's summed view.""" + if ( + (sp.issparse(X) or cpsp.issparse(X)) + and getattr(X, "format", None) in {"csr", "csc"} + and not X.has_canonical_format + ): + X.sum_duplicates() # also sorts indices and sets the canonical flag + return X def _select_groups( labels: pd.Series, selected: list | None, + *, + reference: str = "rest", + skip_empty_groups: bool = False, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: - """Build integer group codes from a categorical Series. - - Parameters - ---------- - labels - Categorical Series (from ``adata.obs[groupby]``). - selected - Group names to keep, or ``None`` for all groups. - Must already include the reference group if applicable. - - Returns - ------- - groups_order - Selected group names as a numpy array. - group_codes - Per-cell int32 codes: ``0..n_groups-1`` for selected cells, - ``n_groups`` (sentinel) for unselected cells. - group_sizes - Number of cells per selected group (int64). - """ + """Build selected group names, per-cell int32 codes, and group sizes. + Unselected cells receive the sentinel code ``n_groups``.""" all_categories = labels.cat.categories if selected is None: selected = list(all_categories) - elif len(selected) > 1: - # Sort to match original category order (scanpy convention) - cat_order = {str(c): i for i, c in enumerate(all_categories)} - selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + # else: preserve the user-provided order. scanpy's select_groups does NOT + # re-sort to category order, so the output column order echoes `groups=`. + + if skip_empty_groups: + counts = { + str(name): int(count) for name, count in labels.value_counts().items() + } + valid_selected = [group for group in selected if counts.get(str(group), 0) >= 2] + if reference != "rest": + ref_matches = [group for group in selected if str(group) == str(reference)] + if ref_matches: + ref_group = ref_matches[0] + if ref_group not in valid_selected: + msg = ( + f"reference = {reference} has fewer than two samples after " + "filtering and cannot be used for rank_genes_groups." + ) + raise ValueError(msg) + selected = valid_selected + if len(selected) == 0: + msg = ( + "No groups with at least two samples remain after applying " + "skip_empty_groups=True." + ) + raise ValueError(msg) n_groups = len(selected) groups_order = np.array(selected) @@ -71,37 +96,17 @@ def _select_groups( np.int64 ) - # Validate singlet groups invalid_groups = {str(selected[i]) for i in range(n_groups) if group_sizes[i] < 2} if invalid_groups: msg = ( f"Could not calculate statistics for groups {', '.join(invalid_groups)} " - "since they only contain one sample." + "since they contain fewer than two samples." ) raise ValueError(msg) return groups_order, group_codes, group_sizes -def _round_up_to_warp(n: int) -> int: - """Round up to nearest multiple of WARP_SIZE, capped at MAX_THREADS_PER_BLOCK.""" - return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) - - -def _select_top_n(scores: NDArray, n_top: int) -> NDArray: - """Select indices of top n scores. - - Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. - This is faster than full sorting when k << n. - """ - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) - partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices - - def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: @@ -110,35 +115,67 @@ def _choose_chunk_size(requested: int | None) -> int: def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: - """ - Extract columns from a CSC matrix via direct indptr pointer slicing. + """Densify a CSC column window into an F-order float64 GPU block. + Slices by indptr so only window nonzeros are touched/transferred.""" + from rapids_singlecell._cuda import _rank_stats_cuda as _rs - Works for both scipy and CuPy CSC matrices. Much faster than - ``X[:, start:stop]`` which rebuilds index arrays internally. - """ s_ptr = int(X_csc.indptr[start]) e_ptr = int(X_csc.indptr[stop]) - chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) - chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) - chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) - csc_chunk = cpsp.csc_matrix( - (chunk_data, chunk_indices, chunk_indptr), shape=(n_rows, stop - start) + out = cp.zeros((n_rows, stop - start), dtype=cp.float64, order="F") + if e_ptr > s_ptr: + chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) + chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) + chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) + _rs.csc_tile_to_dense( + chunk_indptr, + chunk_indices, + chunk_data, + out, + col_lb=0, + col_ub=stop - start, + stream=cp.cuda.get_current_stream().ptr, + ) + return out + + +def _csr_tile_to_dense_block(X, start: int, stop: int) -> cp.ndarray: + """Densify a CSR column window into an F-order float64 GPU block. + Device CSR avoids rebuilding a CSR/CSC slice before densifying.""" + from rapids_singlecell._cuda import _rank_stats_cuda as _rs + + n_rows = X.shape[0] + out = cp.zeros((n_rows, stop - start), dtype=cp.float64, order="F") + if X.nnz == 0: + return out + _rs.csr_tile_to_dense( + cp.asarray(X.indptr), + cp.asarray(X.indices), + cp.asarray(X.data), + out, + col_lb=int(start), + col_ub=int(stop), + stream=cp.cuda.get_current_stream().ptr, ) - return _sparse_to_dense(csc_chunk, order="F").astype(cp.float64) + return out def _get_column_block(X, start: int, stop: int) -> cp.ndarray: """Extract a column block as a dense F-order float64 CuPy array.""" match X: + # Device CSR can densify in one pass without transfer. + # Host CSR intentionally falls through to avoid per-chunk full transfers. + case cpsp.csr_matrix(): + return _csr_tile_to_dense_block(X, start, stop) case sp.csc_matrix() | sp.csc_array(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case sp.spmatrix() | sp.sparray(): chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) - return _sparse_to_dense(chunk, order="F").astype(cp.float64) + return _csc_columns_to_gpu(chunk, 0, chunk.shape[1], X.shape[0]) case cpsp.csc_matrix(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case cpsp.spmatrix(): - return _sparse_to_dense(X[:, start:stop], order="F").astype(cp.float64) + chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) + return _csc_columns_to_gpu(chunk, 0, chunk.shape[1], X.shape[0]) case np.ndarray() | cp.ndarray(): return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") case _: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..a06c6d40 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -1,89 +1,426 @@ from __future__ import annotations import warnings +from dataclasses import dataclass from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc +from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import ( + EPS, + MIN_GROUP_SIZE_WARNING, + _choose_chunk_size, +) if TYPE_CHECKING: from numpy.typing import NDArray from ._core import _RankGenes -MIN_GROUP_SIZE_WARNING = 25 +DEFAULT_WILCOXON_CHUNK_SIZE = 512 +OVR_HOST_CSC_SUB_BATCH = 512 +OVR_HOST_CSR_SUB_BATCH = 2048 +OVR_DEVICE_CSC_SUB_BATCH = 2048 +OVR_DEVICE_CSR_SUB_BATCH = 2048 +OVO_HOST_SPARSE_SUB_BATCH = 256 +OVO_DEVICE_SPARSE_SUB_BATCH = 128 +OVR_DENSE_SUB_BATCH = 64 +OVO_DENSE_TIERED_SUB_BATCH = 256 -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """ - Compute average ranks for each column using GPU kernel. +@dataclass(frozen=True) +class _OvoContext: + codes: np.ndarray + n_groups: int + ireference: int + n_ref: int + ref_row_ids: np.ndarray + test_group_indices: list[int] + all_grp_row_ids: np.ndarray + offsets_np: np.ndarray + offsets_gpu: cp.ndarray + n_all_grp: int + n_test: int + test_sizes: cp.ndarray - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) +def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: + if requested is not None: + return _choose_chunk_size(requested) + return min(DEFAULT_WILCOXON_CHUNK_SIZE, max(1, n_genes)) - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) +def _fill_basic_stats_from_accumulators( + rg: _RankGenes, + group_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: np.ndarray, + *, + n_cells: int, + total_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, +) -> None: + # vars left zero: wilcoxon does not output per-group variance. + n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] + means = group_sums / n + rg.means = cp.asnumpy(means) + rg.vars = np.zeros_like(rg.means) + rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None + + n_rest = cp.float64(n_cells) - n + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums + rest_means = rest_sums / n_rest + rg.means_rest = cp.asnumpy(rest_means) + rg.vars_rest = np.zeros_like(rg.means_rest) + if rg.comp_pts: + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) + rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) + else: + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _fill_ovo_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) + means_slots = group_sums_slots / slot_sizes_dev + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + # vars left zero: wilcoxon does not output per-group variance. + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _fill_ovo_dense_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sum_sq_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + means_slots = group_sums_slots / slot_sizes_dev + vars_slots = group_sum_sq_slots / slot_sizes_dev - means_slots**2 + vars_slots = cp.where( + slot_sizes_dev > 1.0, + vars_slots * slot_sizes_dev / (slot_sizes_dev - 1.0), + 0.0, ) + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + rg.vars[slot_group_indices] = cp.asnumpy(vars_slots) + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) + + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + - if return_sorted: - return matrix, sorted_vals - return matrix +def _ovo_logfoldchanges_from_sums( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, +) -> cp.ndarray: + n_test = int(test_sizes.shape[0]) + mean_group = group_sums_slots[:n_test] / test_sizes[:, None] + mean_ref = group_sums_slots[n_test][None, :] / cp.float64(n_ref) + if rg._log1p_base is not None: + scale = cp.float64(np.log(rg._log1p_base)) + group_expr = cp.expm1(mean_group * scale) + ref_expr = cp.expm1(mean_ref * scale) + else: + group_expr = cp.expm1(mean_group) + ref_expr = cp.expm1(mean_ref) + return cp.log2((group_expr + EPS) / (ref_expr + EPS)) + + +def _wilcoxon_scores( + rank_sums: cp.ndarray, + group_sizes: cp.ndarray, + z_scores: cp.ndarray, + *, + return_u_values: bool, +) -> cp.ndarray: + if not return_u_values: + return z_scores + n_group = group_sizes[:, None] + return rank_sums - n_group * (n_group + 1.0) / 2.0 + + +def _z_scores_pvals( + rank_sums: cp.ndarray, + expected: cp.ndarray, + variance: cp.ndarray, + sizes: cp.ndarray, + *, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Shared Wilcoxon normal-approximation epilogue -> (scores, p_values).""" + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores(rank_sums, sizes, z, return_u_values=return_u_values) + return scores, p_values + + +def _ovr_z_pvals( + rank_sums: cp.ndarray, + group_sizes_dev: cp.ndarray, + rest_sizes: cp.ndarray, + n_cells: int, + tie_corr: cp.ndarray, + *, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Group-vs-rest scores/p-values (tie_corr is ones when not correcting).""" + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + return _z_scores_pvals( + rank_sums, + expected, + variance, + group_sizes_dev, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + + +def _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + *, + use_continuity, + return_u_values, + n_groups, +): + """OVR epilogue: z/p-values -> host -> per-group (idx, scores, pvals).""" + scores, p_values = _ovr_z_pvals( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + scores_host = scores.get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """ - Compute tie correction factor for Wilcoxon test. +def _ovo_z_pvals( + rank_sums: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, + tie_corr_arr: cp.ndarray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> tuple[cp.ndarray, cp.ndarray]: + """Group-vs-reference scores/p-values from rank sums and tie correction.""" + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr + return _z_scores_pvals( + rank_sums, + expected, + variance, + test_sizes, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. - """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) - if n_rows < 2: - return correction +def _finish_ovo( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + *, + tie_correct, + use_continuity, + return_u_values, + rg, + test_group_indices, + logfoldchanges_gpu, +): + """OVO epilogue: z/p-values; stash GPU result if requested, else host tuples.""" + scores, p_values = _ovo_z_pvals( + rank_sums, + test_sizes, + n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + logfoldchanges_gpu, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream +def _host_sparse_data_array(X): + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float64: + return X.data + if data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: + return X.data.astype(np.float32, copy=False) + if data_dtype.kind == "c": + msg = ( + "Wilcoxon sparse input data dtype must be real; complex sparse " + "data is not supported." + ) + raise TypeError(msg) + msg = ( + "Wilcoxon sparse input data dtype must be float32, float64, bool, " + f"or integer; got {data_dtype}." ) + raise TypeError(msg) - return correction + +def _validate_wilcoxon_sparse_dtype(X) -> None: + if not (sp.issparse(X) or cpsp.issparse(X)): + return + data_dtype = np.dtype(X.data.dtype) + if data_dtype.kind == "c": + msg = ( + "Wilcoxon sparse input data dtype must be real; complex sparse " + "data is not supported." + ) + raise TypeError(msg) + if cpsp.issparse(X) and data_dtype not in { + np.dtype(np.float32), + np.dtype(np.float64), + }: + msg = ( + "Wilcoxon device sparse input data dtype must be float32 or " + f"float64; got {data_dtype}." + ) + raise TypeError(msg) + if getattr(X, "format", None) in {"csr", "csc"}: + indices_dtype = np.dtype(X.indices.dtype) + indptr_dtype = np.dtype(X.indptr.dtype) + if indices_dtype != indptr_dtype: + msg = ( + "Wilcoxon sparse indices and indptr must have the same dtype; " + f"got indices={indices_dtype} and indptr={indptr_dtype}." + ) + raise TypeError(msg) + if indices_dtype not in {np.dtype(np.int32), np.dtype(np.int64)}: + msg = ( + "Wilcoxon sparse indices and indptr must be int32 or int64; " + f"got {indices_dtype}." + ) + raise TypeError(msg) + + +def _device_sparse_arrays(X): + """Prepare device-sparse arrays for float32-key Wilcoxon kernels. + float64 data is accepted and cast for ranking; stats stay float64.""" + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float32: + data = X.data + elif data_dtype == np.float64: + data = X.data.astype(cp.float32, copy=False) + elif data_dtype.kind == "c": + msg = ( + "Wilcoxon device sparse input data dtype must be real; complex " + "sparse data is not supported." + ) + raise TypeError(msg) + else: + msg = ( + "Wilcoxon device sparse input data dtype must be float32 or " + f"float64; got {data_dtype}." + ) + raise TypeError(msg) + + # Keep int64 index buffers native and let the nanobind overloads dispatch by + # dtype. Normal CuPy sparse matrices keep indices and indptr in lockstep. + if X.indices.dtype == cp.int64: + indices = X.indices + indptr = X.indptr + else: + indices = X.indices.astype(cp.int32, copy=False) + indptr = X.indptr.astype(cp.int32, copy=False) + return data, indices, indptr def wilcoxon( @@ -92,16 +429,18 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - # Compute basic stats - uses Aggregate if on GPU, else defers to chunks - rg._basic_stats() + # Host dense streams column windows; device dense stays device-resident. + # Aggregate stats on GPU, otherwise compute them inside streaming paths. X = rg.X + _validate_wilcoxon_sparse_dtype(X) + rg._basic_stats() n_cells, n_total_genes = rg.X.shape group_sizes = rg.group_sizes if rg.ireference is not None: - # Compare each group against a specific reference group return _wilcoxon_with_reference( rg, X, @@ -110,8 +449,8 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) - # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( rg, X, @@ -121,24 +460,41 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) -def _wilcoxon_vs_rest( - rg: _RankGenes, - X, - n_cells: int, - n_total_genes: int, - group_sizes: NDArray, - *, - tie_correct: bool, - use_continuity: bool, - chunk_size: int | None, -) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs rest of cells.""" - n_groups = len(rg.groups_order) +def _host_sparse_format(X) -> str | None: + if not isinstance(X, sp.spmatrix | sp.sparray): + return None + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + return X.format + - # Warn for small groups +def _device_sparse_format(X) -> str | None: + if cpsp.isspmatrix_csc(X): + return "csc" + if cpsp.isspmatrix_csr(X): + return "csr" + return None + + +def _host_dense_matrix(X) -> np.ndarray | None: + if not isinstance(X, np.ndarray): + return None + matrix = X + if matrix.dtype.kind != "f" or matrix.dtype.itemsize < 4: + return np.asarray(matrix, dtype=np.float32, order="F") + if matrix.flags.c_contiguous or matrix.flags.f_contiguous: + return matrix + return np.asfortranarray(matrix) + + +def _warn_small_ovr_groups(rg: _RankGenes, group_sizes: NDArray, n_cells: int) -> None: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: @@ -149,176 +505,1044 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 +def _warn_small_ovo_groups( + rg: _RankGenes, ctx: _OvoContext, group_sizes: NDArray +) -> None: + small_groups = [ + str(rg.groups_order[group_index]) + for group_index in ctx.test_group_indices + if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING + ] + if ctx.n_ref > MIN_GROUP_SIZE_WARNING and not small_groups: + return + parts = [] + if small_groups: + parts.append( + f"{len(small_groups)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_groups[:5])}" + f"{'...' if len(small_groups) > 5 else ''})" + ) + if ctx.n_ref <= MIN_GROUP_SIZE_WARNING: + parts.append(f"reference has size {ctx.n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) + + +def _build_ovo_context(rg: _RankGenes, group_sizes: NDArray) -> _OvoContext: + codes = rg.group_codes + n_groups = len(rg.groups_order) + ireference = int(rg.ireference) + n_ref = int(group_sizes[ireference]) + ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) + test_group_indices = [i for i in range(n_groups) if i != ireference] + + offsets = [0] + row_id_parts = [] + for group_index in test_group_indices: + group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) + row_id_parts.append(group_rows) + offsets.append(offsets[-1] + int(group_rows.size)) + + all_grp_row_ids = ( + np.concatenate(row_id_parts).astype(np.int32, copy=False) + if row_id_parts + else np.empty(0, dtype=np.int32) + ) + offsets_np = np.asarray(offsets, dtype=np.int32) + test_sizes = cp.asarray( + group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( + np.float64, copy=False + ) + ) + return _OvoContext( + codes=codes, + n_groups=n_groups, + ireference=ireference, + n_ref=n_ref, + ref_row_ids=ref_row_ids, + test_group_indices=test_group_indices, + all_grp_row_ids=all_grp_row_ids, + offsets_np=offsets_np, + offsets_gpu=cp.asarray(offsets_np), + n_all_grp=int(all_grp_row_ids.size), + n_test=len(test_group_indices), + test_sizes=test_sizes, + ) + + +def _finish_ovo_sparse_stats( + rg: _RankGenes, + ctx: _OvoContext, + group_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: NDArray, +) -> cp.ndarray | None: + if not rg._compute_stats_in_chunks: + return None + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + rg._compute_stats_in_chunks = False + return _ovo_logfoldchanges_from_sums( + rg, + group_sums, + ctx.test_sizes, + ctx.n_ref, + ) + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=ctx.test_group_indices, + n_ref=ctx.n_ref, + ) + return None + + +def _finish_ovo_dense_stats( + rg: _RankGenes, + ctx: _OvoContext, + group_sums: cp.ndarray, + group_sum_sq: cp.ndarray, + group_nnz: cp.ndarray, + *, + group_sizes: NDArray, +) -> cp.ndarray | None: + if not rg._compute_stats_in_chunks: + return None + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + rg._compute_stats_in_chunks = False + return _ovo_logfoldchanges_from_sums( + rg, + group_sums, + ctx.test_sizes, + ctx.n_ref, + ) + _fill_ovo_dense_stats_from_accumulators( + rg, + group_sums, + group_sum_sq, + group_nnz, + group_sizes=group_sizes, + test_group_indices=ctx.test_group_indices, + n_ref=ctx.n_ref, + ) + return None + + +def _run_ovr_host_sparse( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _host_sparse_format(X) + if sparse_format is None: + return None + + n_groups = len(rg.groups_order) + group_codes = rg.group_codes.astype(np.int32, copy=False) + group_sizes_np = group_sizes.astype(np.float64, copy=False) + group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_nnz = rg.comp_pts + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) + compute_totals = bool( + rg._compute_stats_in_chunks and np.any(group_codes == n_groups) + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + + if isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csc": + X.sort_indices() + _wcs.ovr_sparse_csc_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, + ) + else: + X.sort_indices() + _wcs.ovr_sparse_csr_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, + ) + + if rg._compute_stats_in_chunks: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes_np, + n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, + ) + + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + + +def _run_ovr_device_sparse( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _device_sparse_format(X) + if sparse_format is None: + return None + + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + n_groups = len(rg.groups_order) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) - chunk_width = _choose_chunk_size(chunk_size) + if sparse_format == "csc": + _wcs.ovr_sparse_csc_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, + ) + else: + _wcs.ovr_sparse_csr_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, + ) - # Accumulate results per group - all_scores: dict[int, list] = {i: [] for i in range(n_groups)} - all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) +def _run_ovr_signed_sparse_dense( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + host_format = ( + _host_sparse_format(X) if isinstance(X, sp.spmatrix | sp.sparray) else None + ) + device_format = _device_sparse_format(X) + sparse_format = host_format or device_format + if sparse_format is None: + return None - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) + n_groups = len(rg.groups_order) + group_codes_np = rg.group_codes.astype(np.int32, copy=False) + group_codes_gpu = cp.asarray(group_codes_np, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + chunk_cols = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + + if host_format is not None: + X.sort_indices() + compute_stats = rg._compute_stats_in_chunks + compute_nnz = compute_stats and rg.comp_pts + compute_totals = bool(compute_stats and np.any(group_codes_np == n_groups)) + stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + stats_shape if compute_nnz else (1, 1), + dtype=cp.float64, + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + runner = ( + _wcs.ovr_dense_csc_host if host_format == "csc" else _wcs.ovr_dense_csr_host + ) + runner( + _host_sparse_data_array(X), + X.indices, + X.indptr, + group_codes_np, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_stats=compute_stats, + compute_nnz=compute_nnz, + compute_totals=compute_totals, + chunk_cols=chunk_cols, + rank_sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + if compute_stats: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes, + n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, + ) + else: + if isinstance(X, cpsp.spmatrix) and X.format == "csr": + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + runner = ( + _wcs.ovr_dense_csc_device + if device_format == "csc" + else _wcs.ovr_dense_csr_device + ) + runner( + data, + indices, + indptr, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + chunk_cols=chunk_cols, + rank_sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + + +def _run_ovr_host_dense( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + matrix = _host_dense_matrix(X) + if matrix is None: + return None + n_groups = len(rg.groups_order) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_nnz = rg.comp_pts + compute_stats = rg._compute_stats_in_chunks + compute_totals = bool(compute_stats and np.any(rg.group_codes == n_groups)) + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = ( + cp.empty(n_total_genes, dtype=cp.float64) + if tie_correct + else cp.ones(n_total_genes, dtype=cp.float64) + ) + stats_shape = (n_groups, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + (n_groups, n_total_genes) if (compute_stats and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + total_sums = cp.empty( + (1, n_total_genes) if compute_totals else (1, 1), + dtype=cp.float64, + ) + total_nnz = cp.empty( + (1, n_total_genes) if (compute_totals and compute_nnz) else (1, 1), + dtype=cp.float64, + ) + _wc.ovr_rank_dense_host_streaming( + matrix, + group_codes_gpu, + rank_sums, + tie_corr, + group_sums, + group_nnz, + total_sums, + total_nnz, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_nnz=compute_stats and compute_nnz, + compute_stats=compute_stats, + compute_totals=compute_totals, + sub_batch_cols=OVR_DENSE_SUB_BATCH, + ) + if compute_stats: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_nnz, + group_sizes, n_cells=n_cells, + total_sums=total_sums if compute_totals else None, + total_nnz=total_nnz if compute_totals and compute_nnz else None, ) + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) + +def _run_ovr_device_dense( + rg: _RankGenes, + X, + n_cells: int, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + if not isinstance(X, cp.ndarray): + return None + + n_groups = len(rg.groups_order) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = ( + cp.empty(n_total_genes, dtype=cp.float64) + if tie_correct + else cp.ones(n_total_genes, dtype=cp.float64) + ) + + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + block_f32 = cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + n_cols = stop - start + sub_rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) + sub_tie_corr = ( + cp.empty(n_cols, dtype=cp.float64) + if tie_correct + else cp.ones(n_cols, dtype=cp.float64) + ) + _wc.ovr_rank_dense_streaming( + block_f32, + group_codes_gpu, + sub_rank_sums, + sub_tie_corr, + n_rows=n_cells, + n_cols=n_cols, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DENSE_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) + rank_sums[:, start:stop] = sub_rank_sums if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - z_host = z.get() - p_host = p_values.get() - - for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) - all_pvals[idx].append(p_host[idx]) - - # Collect results per group - return [ - (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) - for gi in range(n_groups) - ] + tie_corr[start:stop] = sub_tie_corr + + return _finish_ovr( + rank_sums, + group_sizes_dev, + rest_sizes, + n_cells, + tie_corr, + use_continuity=use_continuity, + return_u_values=return_u_values, + n_groups=n_groups, + ) -def _wilcoxon_with_reference( +def _wilcoxon_vs_rest( rg: _RankGenes, X, + n_cells: int, n_total_genes: int, group_sizes: NDArray, *, tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" - codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference + """Wilcoxon test: each group vs rest of cells.""" + _warn_small_ovr_groups(rg, group_sizes, n_cells) + match X: + case sp.spmatrix() | sp.sparray(): + if rg._sparse_negative_fallback: + result = _run_ovr_signed_sparse_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + else: + result = _run_ovr_host_sparse( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case _ if _device_sparse_format(X) is not None: + if rg._sparse_negative_fallback: + result = _run_ovr_signed_sparse_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + else: + result = _run_ovr_device_sparse( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case np.ndarray(): + result = _run_ovr_host_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case cp.ndarray(): + result = _run_ovr_device_dense( + rg, + X, + n_cells, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case _: + msg = f"Unsupported Wilcoxon OVR input type: {type(X)}" + raise TypeError(msg) + if result is not None: + return result + msg = f"Unsupported Wilcoxon OVR input type: {type(X)}" + raise TypeError(msg) - results: list[tuple[int, NDArray, NDArray]] = [] - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: - continue +def _run_ovo_host_sparse( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _host_sparse_format(X) + if sparse_format is None: + return None - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + n_groups_stats = ctx.n_test + 1 + compute_sums = rg._compute_stats_in_chunks + compute_nnz = rg.comp_pts + group_sums = cp.empty( + (n_groups_stats, n_total_genes) + if (compute_sums or sparse_format == "csc") + else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) + stats_code_lookup = np.full(ctx.n_groups + 1, n_groups_stats, dtype=np.int32) + test_group_indices_np = np.asarray(ctx.test_group_indices, dtype=np.intp) + stats_code_lookup[test_group_indices_np] = np.arange(ctx.n_test, dtype=np.int32) + stats_code_lookup[ctx.ireference] = ctx.n_test + stats_codes = stats_code_lookup[ctx.codes] - # Warn for small groups - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, - ) + if sparse_format == "csc": + X.sort_indices() + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ctx.ref_row_ids] = np.arange(ctx.n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[ctx.all_grp_row_ids] = np.arange(ctx.n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + ref_row_map, + grp_row_map, + ctx.offsets_np, + stats_codes, + rank_sums, + tie_corr_arr, + group_sums, + group_nnz, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=ctx.n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, + ) + else: + X.sort_indices() + _wcs.ovo_streaming_csr_host( + _host_sparse_data_array(X), + X.indices, + X.indptr, + ctx.ref_row_ids, + ctx.all_grp_row_ids, + ctx.offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_nnz, + n_full_rows=X.shape[0], + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_total_genes, + n_test=ctx.n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_sums=compute_sums, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, + ) - # Combined mask: group + reference - mask_obs = codes == group_index - mask_combined = mask_obs | mask_ref + logfoldchanges_gpu = _finish_ovo_sparse_stats( + rg, ctx, group_sums, group_nnz, group_sizes + ) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, + ) - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() - ) +def _run_ovo_device_sparse( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + _group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + sparse_format = _device_sparse_format(X) + if sparse_format is None: + return None + + if isinstance(X, cpsp.spmatrix) and X.format == "csr": + X.sort_indices() + data, indices, indptr = _device_sparse_arrays(X) + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + + if sparse_format == "csc": + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ctx.ref_row_ids] = np.arange(ctx.n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[ctx.all_grp_row_ids] = np.arange(ctx.n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_device( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + ctx.offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_total_genes, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + _wcs.ovo_streaming_csr_device( + data, + indices, + indptr, + cp.asarray(ctx.ref_row_ids, dtype=cp.int32), + cp.asarray(ctx.all_grp_row_ids, dtype=cp.int32), + ctx.offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_total_genes, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=None, + ) + + +def _run_ovo_host_dense( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + matrix = _host_dense_matrix(X) + if matrix is None: + return None + dense_sub_batch_cols = ( + _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + if chunk_size is not None + else OVO_DENSE_TIERED_SUB_BATCH + ) + rank_sums = cp.zeros((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + compute_stats = rg._compute_stats_in_chunks + compute_nnz = compute_stats and rg.comp_pts + n_groups_stats = ctx.n_test + 1 + stats_shape = (n_groups_stats, n_total_genes) if compute_stats else (1, 1) + group_sums = cp.empty(stats_shape, dtype=cp.float64) + group_sum_sq = cp.empty(stats_shape, dtype=cp.float64) + group_nnz = cp.empty( + stats_shape if compute_nnz else (1, 1), + dtype=cp.float64, + ) + _wc.ovo_rank_dense_host_streaming( + matrix, + ctx.ref_row_ids, + ctx.all_grp_row_ids, + ctx.offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sum_sq, + group_nnz, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + compute_nnz=compute_nnz, + compute_stats=compute_stats, + sub_batch_cols=dense_sub_batch_cols, + ) + logfoldchanges_gpu = _finish_ovo_dense_stats( + rg, + ctx, + group_sums, + group_sum_sq, + group_nnz, + group_sizes=group_sizes, + ) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=logfoldchanges_gpu, + ) - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) - chunk_width = _choose_chunk_size(chunk_size) +def _run_ovo_device_dense( + rg: _RankGenes, + X, + ctx: _OvoContext, + n_total_genes: int, + _group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]] | None: + if not isinstance(X, cp.ndarray): + return None - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + ref_rows = cp.asarray(ctx.ref_row_ids, dtype=cp.int32) + grp_rows = cp.asarray(ctx.all_grp_row_ids, dtype=cp.int32) + rank_sums = cp.empty((ctx.n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((ctx.n_test, n_total_genes), dtype=cp.float64) + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + n_cols = stop - start + ref_f32 = cp.asarray(X[ref_rows, start:stop], dtype=cp.float32, order="F") + grp_f32 = cp.asarray(X[grp_rows, start:stop], dtype=cp.float32, order="F") + sub_rank_sums = cp.empty((ctx.n_test, n_cols), dtype=cp.float64) + sub_tie_corr = cp.ones((ctx.n_test, n_cols), dtype=cp.float64) + _wc.ovo_rank_dense_tiered_unsorted_ref( + ref_f32, + grp_f32, + ctx.offsets_gpu, + sub_rank_sums, + sub_tie_corr, + n_ref=ctx.n_ref, + n_all_grp=ctx.n_all_grp, + n_cols=n_cols, + n_groups=ctx.n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) + rank_sums[:, start:stop] = sub_rank_sums + if tie_correct: + tie_corr_arr[:, start:stop] = sub_tie_corr - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + return _finish_ovo( + rank_sums, + ctx.test_sizes, + ctx.n_ref, + tie_corr_arr, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + rg=rg, + test_group_indices=ctx.test_group_indices, + logfoldchanges_gpu=None, + ) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, - n_ref=n_ref, +def _wilcoxon_with_reference( + rg: _RankGenes, + X, + n_total_genes: int, + group_sizes: NDArray, + *, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + return_u_values: bool, +) -> list[tuple[int, NDArray, NDArray]]: + """Wilcoxon test: all selected groups vs a specific reference group.""" + ctx = _build_ovo_context(rg, group_sizes) + if ctx.n_test == 0: + return [] + _warn_small_ovo_groups(rg, ctx, group_sizes) + match X: + case sp.spmatrix() | sp.sparray(): + result = _run_ovo_host_sparse( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, ) - - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) - - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() - - results.append((group_index, scores, pvals)) - - return results + case _ if _device_sparse_format(X) is not None: + result = _run_ovo_device_sparse( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + return_u_values=return_u_values, + ) + case np.ndarray(): + result = _run_ovo_host_dense( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case cp.ndarray(): + result = _run_ovo_device_dense( + rg, + X, + ctx, + n_total_genes, + group_sizes, + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + case _: + msg = f"Unsupported Wilcoxon OVO input type: {type(X)}" + raise TypeError(msg) + if result is not None: + return result + msg = f"Unsupported Wilcoxon OVO input type: {type(X)}" + raise TypeError(msg) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index fa4bbccf..4e1236ad 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -11,6 +11,8 @@ from rapids_singlecell._compat import DaskArray from rapids_singlecell._cuda import _wilcoxon_binned_cuda as _wb +from ._utils import MIN_GROUP_SIZE_WARNING, _get_column_block + if TYPE_CHECKING: from numpy.typing import NDArray @@ -23,11 +25,7 @@ def _fill_sparse_zero_bin(hist: cp.ndarray, group_counts: cp.ndarray) -> None: - """Fill bin 0 with zero counts for sparse histograms (in-place). - - Sparse kernels only populate bins 1..n_bins (nonzero values). - Bin 0 = group_size - sum(bins 1..n_bins) for each gene/group. - """ + """Fill sparse histogram bin 0 from group size minus nonzero-bin counts.""" nonzero_per_group = hist.sum(axis=2) # (n_genes, n_groups) hist[:, :, 0] = group_counts[None, :].astype(cp.uint32) - nonzero_per_group @@ -71,39 +69,7 @@ def wilcoxon_binned( chunk_size: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, ) -> list[tuple[int, NDArray, NDArray]]: - """Histogram-based approximate Wilcoxon rank-sum test. - - Approximates ranks by discretizing expression values into ``n_bins`` - fixed-width bins, then computing rank sums from cumulative histogram - counts. This avoids the O(n log n) per-gene sort required by exact - Wilcoxon, making it feasible for datasets with millions of cells and - compatible with Dask arrays. - - Supports both one-vs-rest (``reference='rest'``) and one-vs-one - (``reference=''``) comparisons. - - Parameters - ---------- - rg - The _RankGenes instance. - tie_correct - Adjust the variance for ties. In the binned approach each bin - acts as a tie group, so the correction uses the bin counts - directly. - n_bins - Number of histogram bins. Higher = better approximation. - Default is 1000 for in-memory arrays and 200 for Dask arrays. - chunk_size - Genes processed per GPU batch. Controls peak GPU memory. - bin_range - How to determine the histogram bin range. - ``None`` (default) uses ``'auto'`` for in-memory arrays and - ``'log1p'`` for Dask arrays (to avoid a costly data scan). - ``'log1p'`` uses a fixed [0, 15] range suitable for - log1p-normalized data. - ``'auto'`` computes the actual (min, max) of the data. Use this - for z-scored or unnormalized data. - """ + """Histogram-based approximate Wilcoxon rank-sum test.""" if not rg.is_log1p: warnings.warn( "wilcoxon_binned expects log-normalized data " @@ -119,36 +85,43 @@ def wilcoxon_binned( if n_bins is None: n_bins = _DASK_N_BINS if isinstance(X, DaskArray) else _DEFAULT_N_BINS - # Sparse kernels assume non-negative data (pre-fill+correct pattern). - # Dense kernel handles any range. - # NOTE: Dask sparse is not validated here because checking .data.min() - # would require materializing all blocks. The sparse histogram kernels - # will silently produce incorrect results for negative Dask sparse data. - if not isinstance(X, DaskArray) and cpsp.issparse(X) and X.nnz > 0: - if float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. The sparse histogram " - "kernels assume non-negative data. Convert to dense or use " - "bin_range='auto' with a dense array." - ) - raise ValueError(msg) - n_groups = len(rg.groups_order) n_cells, n_genes = X.shape group_sizes = rg.group_sizes - # group_codes: 0..n_groups-1 for selected cells, n_groups (sentinel) - # for unselected. For vs-rest, unselected cells are binned into a - # dummy group so they contribute to total counts for correct midranks. - # For vs-reference, the kernel bounds guard (grp >= n_groups) skips them. + # Dask sparse cannot bin negatives correctly because implicit zeros use bin 0. + # Refuse instead of silently mis-ranking; in-memory sparse uses dense fallback. + if isinstance(X, DaskArray) and cpsp.issparse(X._meta): + + def _block_data_min(block): + if block.nnz > 0: + return block.data.min().reshape(1) + return cp.zeros(1, dtype=block.dtype) + + data_min = float( + X.map_blocks( + _block_data_min, + dtype=X.dtype, + drop_axis=1, + chunks=((1,) * len(X.chunks[0]),), + ) + .min() + .compute() + ) + if data_min < 0: + raise ValueError( + "wilcoxon_binned does not support negative values in Dask " + "sparse input; the binned approximation mis-ranks implicit " + "zeros. Densify the data or use a nonnegative representation." + ) + + # group_codes use n_groups as sentinel for unselected cells. + # vs-rest bins sentinels for totals; vs-reference kernels skip them. group_codes_np = rg.group_codes has_unselected = bool(np.any(group_codes_np == n_groups)) - # For one-vs-one with a group subset, only the selected groups' cells - # matter for pairwise rankings. Filter X down so kernels don't iterate - # over irrelevant cells. For Dask we can't cheaply subset rows, but - # the kernel bounds guard (grp >= n_groups → skip) avoids wasted - # atomicAdds, so we just clear the flag without allocating a dummy group. + # One-vs-one only ranks selected groups; filter in-memory rows. + # Dask keeps rows but kernels skip sentinels, avoiding dummy-group atomics. if ireference is not None and has_unselected: if isinstance(X, DaskArray): has_unselected = False @@ -173,7 +146,7 @@ def wilcoxon_binned( ): if gi == ireference: continue - if size <= 25 or n_ref <= 25: + if size <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (reference {n_ref}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", @@ -183,7 +156,7 @@ def wilcoxon_binned( else: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size - if size <= 25 or rest <= 25: + if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (rest {rest}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", @@ -195,6 +168,17 @@ def wilcoxon_binned( if bin_range is None: bin_range = "log1p" if isinstance(X, DaskArray) else "auto" + # The fixed log1p range assumes nonnegative data. + # Signed sparse fallback needs data-driven auto range to avoid clamping. + if rg._sparse_negative_fallback and bin_range == "log1p": + warnings.warn( + "bin_range='log1p' is invalid for sparse input with negative values " + "(the fixed [0, 15] range would clamp them); using bin_range='auto'.", + RuntimeWarning, + stacklevel=4, + ) + bin_range = "auto" + # Prepare GPU arrays and bin arithmetic if bin_range == "auto": bin_low, bin_high = _data_range(X) @@ -224,6 +208,7 @@ def wilcoxon_binned( "tie_correct": tie_correct, "use_continuity": use_continuity, "ireference": ireference, + "force_dense": rg._sparse_negative_fallback, } # Pre-allocate output @@ -270,13 +255,25 @@ def process_gene_batch( tie_correct: bool = False, use_continuity: bool = False, ireference: int | None = None, + force_dense: bool = False, ) -> tuple[cp.ndarray, cp.ndarray]: """Process one gene batch, dispatching on Dask vs in-memory.""" n_hist_groups = n_cells_per_group_hist.shape[0] n_genes_batch = stop - start is_sparse = False - if isinstance(X, DaskArray): + if force_dense and cpsp.issparse(X): + # Negative sparse fallback: bin 0 is only correct for nonnegative data. + # Densify the column window so dense bins span the full [min, max]. + hist = _launch_dense( + _get_column_block(X, start, stop), + group_codes, + n_hist_groups, + n_bins=n_bins, + bin_low=bin_low, + inv_bin_width=inv_bin_width, + ) + elif isinstance(X, DaskArray): hist = _process_dask( X, start=start, @@ -403,12 +400,7 @@ def _compute_stats_vs_ref( tie_correct: bool = False, use_continuity: bool = False, ) -> tuple[cp.ndarray, cp.ndarray]: - """Compute Wilcoxon z-scores for each group vs a specific reference. - - For each group *g*, midranks are derived from the pairwise histogram - ``hist_g + hist_ref`` so that only cells in the compared pair - contribute to the ranking. - """ + """Compute Wilcoxon z-scores for each group vs a specific reference.""" # hist shape: (n_genes, n_groups, n_bins_total) ref_hist = hist[:, ireference : ireference + 1, :] # (n_genes, 1, n_bins_total) @@ -556,13 +548,7 @@ def _process_dask( inv_bin_width: float, n_bins_total: int, ) -> cp.ndarray: - """Build histogram from a Dask array. - - Receives the full (unsliced) Dask array and column range - ``[start, stop)``. Column selection happens inside each block - handler on the materialised CuPy chunk, keeping the Dask graph - simple (no column-slice node per gene batch). - """ + """Build a column-range histogram from an unsliced Dask array.""" import dask.array as da if cpsp.isspmatrix_csr(X._meta): diff --git a/tests/dask/test_dask_rank_wilcoxon_binned.py b/tests/dask/test_dask_rank_wilcoxon_binned.py index 5f49fb6b..dc23b317 100644 --- a/tests/dask/test_dask_rank_wilcoxon_binned.py +++ b/tests/dask/test_dask_rank_wilcoxon_binned.py @@ -137,3 +137,72 @@ def test_wilcoxon_binned_dask_reference(client, data_kind): ) _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +@pytest.mark.parametrize("data_kind", ["sparse", "dense"]) +def test_wilcoxon_binned_dask_auto_range(client, data_kind): + """bin_range='auto' exercises the Dask _data_range branches (per-block + min/max via map_blocks), which the bin_range='log1p' tests never reach.""" + adata, dask_data, groupby = _setup_data(data_kind) + + for ad_ in (adata, dask_data): + rsc.tl.rank_genes_groups( + ad_, + groupby=groupby, + method="wilcoxon_binned", + n_bins=200, + bin_range="auto", + use_raw=False, + ) + + _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +@pytest.mark.parametrize("data_kind", ["sparse", "dense"]) +def test_wilcoxon_binned_dask_reference_subset(client, data_kind): + """Dask + reference + groups-subset together (the has_unselected Dask + branch where unselected cells coexist with a reference group).""" + adata, dask_data, groupby = _setup_data(data_kind) + cats = [str(c) for c in adata.obs[groupby].cat.categories] + groups = cats[1:4] # subset that excludes the reference -> unselected cells exist + reference = cats[0] + + for ad_ in (adata, dask_data): + rsc.tl.rank_genes_groups( + ad_, + groupby=groupby, + method="wilcoxon_binned", + groups=groups, + reference=reference, + n_bins=1000, + bin_range="log1p", + use_raw=False, + ) + + _compare_scores(adata.uns["rank_genes_groups"], dask_data.uns["rank_genes_groups"]) + + +def test_wilcoxon_binned_dask_negative_sparse_raises(client): + """Dask sparse input with a stored negative is refused (the binned histogram + cannot place implicit zeros correctly for signed data).""" + import anndata as ad_mod + import pandas as pd + import scipy.sparse as sp + + rng = np.random.default_rng(0) + X = np.abs(rng.standard_normal((60, 8))).astype(np.float32) + X[X < 0.5] = 0.0 + X[0, 0] = -1.0 # one stored negative + obs = pd.DataFrame({"g": pd.Categorical([f"{i % 3}" for i in range(60)])}) + var = pd.DataFrame(index=[f"v{i}" for i in range(8)]) + adata = ad_mod.AnnData(X=sp.csr_matrix(X), obs=obs, var=var) + adata.X = as_sparse_cupy_dask_array(adata.X).persist() + + with pytest.raises(ValueError, match="negative values in Dask sparse"): + rsc.tl.rank_genes_groups( + adata, + groupby="g", + method="wilcoxon_binned", + bin_range="auto", + use_raw=False, + ) diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index e1684536..03500899 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import anndata as ad import numpy as np import pytest import scanpy as sc @@ -10,6 +11,10 @@ import rapids_singlecell as rsc +def _make_nonnegative(adata): + adata.X = np.abs(adata.X) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,11 +23,15 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) if sparse: + adata_gpu.X = adata_gpu.X.astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() + if sparse: + adata_cpu.X = adata_cpu.X.astype(np.float64) rsc.tl.rank_genes_groups( adata_gpu, @@ -75,6 +84,7 @@ def test_rank_genes_groups_ttest_honors_layer_and_use_raw(reference, method): np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) base.obs["blobs"] = base.obs["blobs"].astype("category") + _make_nonnegative(base) base.layers["signal"] = base.X.copy() ref_adata = base.copy() @@ -123,6 +133,7 @@ def test_rank_genes_groups_ttest_subset_and_bonferroni(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -218,6 +229,7 @@ def test_rank_genes_groups_ttest_with_renamed_categories( np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # First run with original category names rsc.tl.rank_genes_groups(adata, "blobs", method=method, reference=reference_before) @@ -242,10 +254,12 @@ def test_rank_genes_groups_ttest_with_renamed_categories( @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): - """Test that group order doesn't affect results.""" + """Group order sets the output column order (matching scanpy); the per-group + statistics themselves are order-independent.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) bdata = adata.copy() groups = ["0", "1", "2", "3"] if reference != "rest" else ["0", "2", "3"] @@ -258,9 +272,13 @@ def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): bdata, "blobs", method=method, groups=groups_reversed, reference=reference ) - expected_groups = {g for g in groups if g != reference} - assert set(adata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups - assert set(bdata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups + # Column order echoes the user-provided group order (reference excluded). + assert adata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups if g != reference + ) + assert bdata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups_reversed if g != reference + ) # Pick a group that's not the reference for comparison test_group = "3" if reference != "3" else "0" @@ -285,6 +303,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() # Run with pts=True @@ -341,13 +360,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): def test_rank_genes_groups_ttest_direct_scipy(): - """Test t-test scores directly against scipy.stats.ttest_ind on two matrices. - - Creates a simple two-group dataset and compares rapids_singlecell t-test - directly against scipy.stats.ttest_ind without intermediate statistics. - """ - import anndata as ad - + """Compare rapids_singlecell t-test scores directly to scipy.stats.ttest_ind.""" np.random.seed(42) n_group1, n_group2, n_genes = 50, 60, 20 @@ -357,6 +370,9 @@ def test_rank_genes_groups_ttest_direct_scipy(): # Combine into AnnData X = np.vstack([X_group1, X_group2]) + X -= X.min() + X_group1 = X[:n_group1] + X_group2 = X[n_group1:] obs = {"group": ["A"] * n_group1 + ["B"] * n_group2} adata = ad.AnnData(X=X, obs=obs) adata.obs["group"] = adata.obs["group"].astype("category") @@ -390,15 +406,11 @@ def test_rank_genes_groups_ttest_direct_scipy(): def test_rank_genes_groups_ttest_matches_scipy(): - """Test that t-test scores match scipy computation directly. - - This test verifies that our variance clipping fix produces correct results - by comparing against scipy.stats.ttest_ind_from_stats with properly computed - (non-negative) variances. Uses real pbmc68k_reduced dataset at float64 precision. - """ + """Compare t-test scores to scipy stats with nonnegative variances.""" adata = pbmc68k_reduced() # Convert to float64 for maximum precision in comparison adata.X = adata.X.astype(np.float64) + _make_nonnegative(adata) # Run rapids_singlecell t-test rsc.tl.rank_genes_groups(adata, "bulk_labels", method="t-test", use_raw=False) @@ -461,6 +473,7 @@ def test_rank_genes_groups_ttest_mask_var_array(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Create mask to select only first 5 genes mask = np.array([True] * 5 + [False] * 5) @@ -488,6 +501,7 @@ def test_rank_genes_groups_ttest_mask_var_string(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Add mask column to adata.var adata.var["highly_variable"] = [True] * 6 + [False] * 4 @@ -514,6 +528,7 @@ def test_rank_genes_groups_ttest_mask_var_matches_scanpy(method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=3, n_observations=150) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() mask = np.array([True, False, True, False, True, True, False, True]) @@ -546,6 +561,7 @@ def test_rank_genes_groups_ttest_rankby_abs(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) adata_abs = adata.copy() # Run without rankby_abs @@ -573,6 +589,7 @@ def test_rank_genes_groups_ttest_key_added(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) custom_key = "my_custom_key" diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 7f32f0e5..9a166611 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,16 +1,306 @@ from __future__ import annotations import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + +def _make_nonnegative(adata): + adata.X = np.abs(np.asarray(adata.X)).astype(np.float32) + return adata + + +# Sparse Wilcoxon negative values must fall back to dense full-sort ranking. +# Covers Wilcoxon OVR/OVO and binned OVR; other methods accept signed sparse. +@pytest.mark.parametrize( + ("method", "reference"), + [("wilcoxon", "rest"), ("wilcoxon_binned", "rest"), ("wilcoxon", "b")], +) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_rank_genes_groups_sparse_negative_values_fallback(method, reference, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + [-2.0, 1.0, 0.0], + [1.0, 0.0, 3.0], + ], + dtype=np.float64, + ) + obs = pd.DataFrame({"group": pd.Categorical(list("aaabbb"), categories=["a", "b"])}) + var = pd.DataFrame(index=["g0", "g1", "g2"]) + + sparse_adata = sc.AnnData(X=_to_format(X, fmt), obs=obs.copy(), var=var.copy()) + dense_fmt = "cupy_dense" if fmt.startswith("cupy") else "numpy_dense" + dense_adata = sc.AnnData(X=_to_format(X, dense_fmt), obs=obs.copy(), var=var.copy()) + + kw = {"method": method, "reference": reference, "use_raw": False} + rsc.tl.rank_genes_groups(sparse_adata, "group", **kw) + rsc.tl.rank_genes_groups(dense_adata, "group", **kw) + + # Sparse-with-negatives falls back to the dense ranking -> identical result. + sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] + dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] + for group in sp_scores.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_scores[group], dtype=float), + np.asarray(dn_scores[group], dtype=float), + rtol=1e-13, + atol=1e-13, + ) + + +@pytest.mark.parametrize("layout", ["csr", "csc"]) +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_device_sparse_int64_indptr_matches_scanpy(layout, reference): + # Real int64 indptr needs nnz > 2^31, so CI promotes a small matrix. + # cupy >= 14.1 preserves the promoted int64 buffers for overload coverage. + rng = np.random.default_rng(0) + dense = np.abs(rng.standard_normal((150, 8))).astype(np.float32) + dense[dense < 0.5] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{j}" for j in range(8)]) + + ctor = cpsp.csr_matrix if layout == "csr" else cpsp.csc_matrix + mat = ctor(cp.asarray(dense)) + mat.indptr = mat.indptr.astype(cp.int64) + mat.indices = mat.indices.astype(cp.int64) + assert mat.indptr.dtype == cp.int64 + + adata = sc.AnnData(X=mat, obs=obs.copy(), var=var.copy()) + adata_cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(adata_cpu, "group", **kw) + g = adata.uns["rank_genes_groups"] + c = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals", "pvals_adj"): + for grp in g[field].dtype.names: + np.testing.assert_allclose( + np.asarray(g[field][grp], dtype=float), + np.asarray(c[field][grp], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): + np.random.seed(42) + adata_rsc = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_rsc) + adata_rsc.obs["blobs"] = adata_rsc.obs["blobs"].astype("category") + adata_rsc.X = sp.csr_matrix(adata_rsc.X) + adata_cpu = adata_rsc.copy() + adata_cpu.X = adata_cpu.X.toarray() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "reference": "1", + "use_raw": False, + "tie_correct": True, + "n_genes": 4, + } + rsc.tl.rank_genes_groups(adata_rsc, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + rsc_result = adata_rsc.uns["rank_genes_groups"] + assert isinstance(rsc_result["names"], np.ndarray) + assert rsc_result["names"].dtype.names == ("0", "2") + assert tuple(rsc_result["names"][0]) == tuple( + adata_cpu.uns["rank_genes_groups"]["names"][0] + ) + np.testing.assert_array_equal( + rsc_result["names"].copy(), + np.asarray(rsc_result["names"]), + ) + + h5ad_path = tmp_path / "rank_genes_groups.h5ad" + adata_rsc.write_h5ad(h5ad_path) + adata_rsc = sc.read_h5ad(h5ad_path) + + rsc_df = sc.get.rank_genes_groups_df(adata_rsc, group=None) + scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) + pd.testing.assert_frame_equal(rsc_df, scanpy_df) + + +def test_rank_genes_groups_return_format_removed(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(TypeError, match="return_format has been removed"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + return_format="arrays", + ) + + +@pytest.mark.parametrize("reference", ["rest", "b"]) +@pytest.mark.parametrize( + "fmt", + ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_dense", "cupy_csr", "cupy_csc"], +) +def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): + X = np.array( + [ + [5.0, 0.0, 1.0, 2.0], + [4.0, 0.0, 1.0, 2.0], + [1.0, 3.0, 2.0, 2.0], + [0.0, 2.0, 2.0, 2.0], + [2.0, 1.0, 0.0, 3.0], + [3.0, 1.0, 0.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "b", "b", "c", "c"]) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference=reference, + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + result = adata.uns["rank_genes_groups"] + assert result["params"]["return_u_values"] is True + assert result["scores"].dtype["a"] == np.dtype("float64") + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + mask_group = labels == "a" + mask_ref = labels != "a" if reference == "rest" else labels == reference + expected = np.array( + [ + mannwhitneyu( + X[mask_group, gene], + X[mask_ref, gene], + alternative="two-sided", + ).statistic + for gene in range(X.shape[1]) + ], + dtype=np.float64, + ) + + gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} + expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) + np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) + + +def test_rank_genes_groups_wilcoxon_dense_edge_cases_match_scipy(): + X = np.array( + [ + [1.0, 5.0, 0.0, 2.0, 1.0], + [2.0, 5.0, 0.0, 2.0, 1.0], + [3.0, 5.0, 1.0, 2.0, 1.0], + [4.0, 5.0, 1.0, 3.0, 2.0], + [5.0, 5.0, 1.0, 3.0, 2.0], + [6.0, 5.0, 2.0, 3.0, 2.0], + [7.0, 5.0, 2.0, 4.0, 3.0], + [8.0, 5.0, 2.0, 4.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "a", "a", "b", "b", "b", "b"]) + adata = sc.AnnData( + X=X, + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=["no_ties", "all_ties", "zero_ties", "mixed", "pairs"]), + ) + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference="b", + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + expected_u = {} + for idx, name in enumerate(adata.var_names): + result = mannwhitneyu( + X[labels == "a", idx], + X[labels == "b", idx], + alternative="two-sided", + method="asymptotic", + use_continuity=True, + ) + expected_u[name] = result.statistic + + np.testing.assert_allclose( + df["scores"].to_numpy(), + np.array([expected_u[name] for name in df["names"]]), + rtol=1e-13, + atol=1e-15, + ) + assert np.isfinite(df["pvals"]).all() + + +def test_rank_genes_groups_return_u_values_requires_wilcoxon(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="only supported for method='wilcoxon'"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="t-test", + use_raw=False, + return_u_values=True, + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,6 +308,7 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars """Test wilcoxon matches scanpy output across configurations.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: @@ -55,11 +346,13 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - np.testing.assert_allclose(gpu_values, cpu_values, rtol=1e-13, atol=1e-15) + atol = 1e-15 + np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] assert params["use_raw"] is False @@ -69,11 +362,46 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars assert params["reference"] == reference +def test_rank_genes_groups_wilcoxon_dense_ovr_ties_match_scanpy(): + rng = np.random.default_rng(16) + X = rng.integers(0, 40, size=(128, 7)).astype(np.float32) + labels = rng.integers(0, 7, size=128).astype(str) + adata_gpu = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "rest", + "use_raw": False, + "tie_correct": True, + "n_genes": adata_gpu.n_vars, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for group in gpu_result["scores"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + np.testing.assert_allclose( + gpu_result["scores"][group], cpu_result["scores"][group], rtol=1e-13 + ) + np.testing.assert_allclose( + gpu_result["pvals"][group], cpu_result["pvals"][group], rtol=1e-13 + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) + _make_nonnegative(base) base.obs["blobs"] = base.obs["blobs"].astype("category") base.layers["signal"] = base.X.copy() @@ -121,6 +449,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): """Test group subsetting and bonferroni correction.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -148,6 +477,505 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + _make_nonnegative(adata) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["valid"] * 10 + ["singleton"], + categories=["ref", "valid", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + n_genes=3, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert result["names"].dtype.names == ("valid",) + assert result["scores"].dtype.names == ("valid",) + + +def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + _make_nonnegative(adata) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["singleton"], + categories=["ref", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "names" not in result + assert result["params"]["reference"] == "ref" + + +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-13 + atol = 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): + rng = np.random.default_rng(42) + dense = rng.poisson(1.0, size=(80, 12)).astype(np.float32) + dense[rng.random(dense.shape) < 0.55] = 0 + sorted_csr = sp.csr_matrix(dense) + unsorted_csr = sorted_csr.copy() + for row in range(unsorted_csr.shape[0]): + start, stop = unsorted_csr.indptr[row : row + 2] + order = np.arange(stop - start)[::-1] + unsorted_csr.indices[start:stop] = unsorted_csr.indices[start:stop][order] + unsorted_csr.data[start:stop] = unsorted_csr.data[start:stop][order] + unsorted_csr.has_sorted_indices = False + + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["ref"] * 20 + ["a"] * 20 + ["b"] * 20 + ["c"] * 20, + categories=["ref", "a", "b", "c"], + ) + } + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(dense.shape[1])]) + sorted_adata = sc.AnnData(X=sorted_csr, obs=obs.copy(), var=var.copy()) + unsorted_adata = sc.AnnData(X=unsorted_csr, obs=obs.copy(), var=var.copy()) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "ref", + "use_raw": False, + "tie_correct": True, + "n_genes": dense.shape[1], + } + rsc.tl.rank_genes_groups(sorted_adata, **kw) + rsc.tl.rank_genes_groups(unsorted_adata, **kw) + + sorted_result = sorted_adata.uns["rank_genes_groups"] + unsorted_result = unsorted_adata.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in sorted_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(unsorted_result[field][group], dtype=float), + np.asarray(sorted_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + "numpy_dense", + "scipy_csr", + "scipy_csc", + "cupy_dense", + "cupy_csr", + "cupy_csc", + ], +) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 5, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-13 + atol = 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + +def _make_sized_groups_adata(group_sizes, n_genes, seed=0): + """AnnData with exact per-group sizes (drives OVO tier selection by max size).""" + rng = np.random.default_rng(seed) + n_obs = int(sum(group_sizes)) + X = np.abs(rng.standard_normal((n_obs, n_genes))).astype(np.float32) + X[X < 0.3] = 0.0 # zeros create tie groups, exercising tie correction + labels = np.concatenate( + [np.full(sz, f"g{i}", dtype=object) for i, sz in enumerate(group_sizes)] + ) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"gene_{j}" for j in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + adata.uns["log1p"] = {"base": None} + return adata + + +# OVO tier coverage: standard blobs hit only MEDIUM. +# These cases force LARGE fused-smem sort and HUGE CUB segmented sort. +@pytest.mark.parametrize( + "fmt", + ["numpy_dense", "cupy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"], +) +@pytest.mark.parametrize("tie_correct", [False, True]) +@pytest.mark.parametrize("big", [700, 3000], ids=["large_fused", "huge_cub"]) +def test_wilcoxon_ovo_large_group_tiers_match_scanpy(fmt, tie_correct, big): + # g0 = reference, g1 = the large test group that drives tier selection. + adata_gpu = _make_sized_groups_adata([60, big, 45], n_genes=6, seed=1) + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": "g0", + "tie_correct": tie_correct, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu = adata_gpu.uns["rank_genes_groups"] + cpu = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in gpu[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu[field][group], dtype=float), + np.asarray(cpu[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +# Many groups force global-memory accumulators, matching perturbation-scale DE. +# scanpy is too slow here, so this guards cross-format agreement at gmem scale. +@pytest.mark.parametrize("tie_correct", [False, True]) +def test_wilcoxon_ovr_many_groups_gmem_formats_agree(tie_correct): + adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=3) + ref = None + for fmt in ("numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=tie_correct, + n_genes=6, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + if ref is None: + ref = cur + continue + for field in ("scores", "pvals"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +# Host-dense OVR gmem buffers are reused round-robin and must be zeroed per batch. +# This forces enough groups and genes to wrap per-stream rank-sum buffers. +@pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn +def test_wilcoxon_ovr_dense_gmem_host_streaming_buffer_reuse(): + adata = _make_sized_groups_adata([2] * 6200, n_genes=400, seed=7) + ref = None + for fmt in ("cupy_dense", "numpy_dense", "cupy_csr"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + if ref is None: + ref = cur + continue + for field in ("scores", "pvals"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +# Host-dense OVR has only float32/float64 nanobind overloads. +# Other numpy numeric dtypes must cast to float32 rather than raise. +@pytest.mark.parametrize( + "data_dtype", [np.int32, np.int64, np.uint16, np.float16, bool] +) +def test_wilcoxon_dense_nonfloat_data_matches_float32(data_dtype): + rng = np.random.default_rng(5) + n_obs, n_genes = 120, 8 + counts = rng.integers(0, 5, size=(n_obs, n_genes)) + if data_dtype is bool: + counts = counts > 2 + typed = np.ascontiguousarray(counts.astype(data_dtype)) + f32 = np.ascontiguousarray(counts.astype(np.float32)) + labels = np.array([f"{i % 3}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + + def run(arr): + adata = sc.AnnData(X=arr, obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + n_genes=n_genes, + ) + return adata.uns["rank_genes_groups"] + + r_typed = run(typed) + r_f32 = run(f32) + for grp in r_typed["scores"].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_typed["scores"][grp], dtype=float), + np.asarray(r_f32["scores"][grp], dtype=float), + ) + + +# F-contiguous host-dense numpy hits the F-order host-streaming overload. +# It must match the C-order run on identical data. +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_wilcoxon_ovr_fortran_order_host_dense_matches_c_order(dtype): + rng = np.random.default_rng(11) + X = np.abs(rng.standard_normal((300, 40))).astype(dtype) + X[X < 0.3] = 0.0 + labels = rng.integers(0, 5, 300) + obs = pd.DataFrame({"group": pd.Categorical([f"g{c}" for c in labels])}) + var = pd.DataFrame(index=[f"g{j}" for j in range(40)]) + + def run(arr): + adata = sc.AnnData(X=arr, obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + ) + return adata.uns["rank_genes_groups"] + + xf = np.asfortranarray(X) + assert xf.flags.f_contiguous + r_f = run(xf) + r_c = run(np.ascontiguousarray(X)) + for field in ("scores", "pvals", "logfoldchanges"): + for grp in r_f[field].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_f[field][grp], dtype=float), + np.asarray(r_c[field][grp], dtype=float), + ) + + +# Guards host sparse OVR smem packing for pts=True, where nnz offset once overran. +# n_groups=50 stays on smem but reaches the formerly faulting regime. +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +def test_wilcoxon_ovr_pts_many_groups_match_scanpy(fmt): + adata_gpu = _make_sized_groups_adata([40] * 50, n_genes=8, seed=4) + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + "pts": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu = adata_gpu.uns["rank_genes_groups"] + cpu = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in gpu[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu[field][group], dtype=float), + np.asarray(cpu[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + gpu_pts, cpu_pts = gpu["pts"], cpu["pts"] + assert list(gpu_pts.columns) == list(cpu_pts.columns) + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +# Companion gmem-scale check with pts=True. +# It exercises global cast-accumulate and analytic-zero nnz paths. +def test_wilcoxon_ovr_many_groups_gmem_pts_formats_agree(): + adata = _make_sized_groups_adata([26] * 3100, n_genes=6, seed=5) + ref = None + for fmt in ("numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"): + a = adata.copy() + a.X = _to_format(adata.X, fmt) + rsc.tl.rank_genes_groups( + a, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + pts=True, + n_genes=6, + ) + r = a.uns["rank_genes_groups"] + cur = { + field: np.vstack( + [np.asarray(r[field][n], dtype=float) for n in r[field].dtype.names] + ) + for field in ("scores", "pvals") + } + cur["pts"] = r["pts"].values + if ref is None: + ref = cur + continue + for field in ("scores", "pvals", "pts"): + np.testing.assert_allclose( + cur[field], ref[field], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + @pytest.mark.parametrize( ("groups", "reference"), [ @@ -158,9 +986,8 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): ], ) @pytest.mark.parametrize("tie_correct", [False, True]) -@pytest.mark.parametrize("pre_load", [False, True]) def test_rank_genes_groups_wilcoxon_subset_matches_scanpy( - groups, reference, tie_correct, pre_load + groups, reference, tie_correct ): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=5, n_observations=200) @@ -175,7 +1002,6 @@ def test_rank_genes_groups_wilcoxon_subset_matches_scanpy( reference=reference, use_raw=False, tie_correct=tie_correct, - pre_load=pre_load, ) sc.tl.rank_genes_groups( adata_cpu, @@ -221,6 +1047,7 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( """Test with renamed category labels.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") # First run with original category names @@ -249,9 +1076,11 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): - """Test that group order doesn't affect results.""" + """Group order sets the output column order (matching scanpy); the per-group + statistics themselves are order-independent.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") bdata = adata.copy() @@ -265,9 +1094,13 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): bdata, "blobs", method="wilcoxon", groups=groups_reversed, reference=reference ) - expected_groups = {g for g in groups if g != reference} - assert set(adata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups - assert set(bdata.uns["rank_genes_groups"]["names"].dtype.names) == expected_groups + # Column order echoes the user-provided group order (reference excluded). + assert adata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups if g != reference + ) + assert bdata.uns["rank_genes_groups"]["names"].dtype.names == tuple( + g for g in groups_reversed if g != reference + ) # Pick a group that's not the reference for comparison test_group = "3" if reference != "3" else "0" @@ -286,11 +1119,11 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): @pytest.mark.parametrize("reference", ["rest", "1"]) -@pytest.mark.parametrize("pre_load", [True, False]) -def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): +def test_rank_genes_groups_wilcoxon_pts(reference): """Test that pts (fraction of cells expressing) is computed correctly.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") adata_cpu = adata_gpu.copy() @@ -303,7 +1136,6 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): pts=True, tie_correct=False, reference=reference, - pre_load=pre_load, ) sc.tl.rank_genes_groups( adata_cpu, @@ -350,9 +1182,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): ) -# ============================================================================ -# Ground-truth validation against scipy.stats.mannwhitneyu -# ============================================================================ +# Ground-truth validation against scipy.stats.mannwhitneyu. def _make_perturbation_adata( @@ -506,186 +1336,788 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): ) -# ============================================================================ -# Tests for ranking and tie correction kernels (edge cases from scipy) -# ============================================================================ +def _make_count_adata(seed=0, n_obs=120, n_genes=6, n_groups=3): + # Integer-valued counts as float64: float32-exact, zeros create ties. + rng = np.random.default_rng(seed) + X = rng.integers(0, 8, size=(n_obs, n_genes)).astype(np.float64) + X[X < 2] = 0.0 # extra zeros -> implicit-zero tie blocks + labels = np.array([f"{i % n_groups}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + adata.uns["log1p"] = {"base": None} + return adata + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_wilcoxon_host_sparse_float64_data_matches_scanpy(fmt, reference): + # float64 host-sparse data exercises the *_f64 kernel bindings. + adata = _make_count_adata(seed=3) + adata_cpu = adata.copy() + mat = sp.csr_matrix(adata.X) if fmt == "scipy_csr" else sp.csc_matrix(adata.X) + assert mat.dtype == np.float64 + adata.X = mat + + kw = { + "groupby": "group", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": adata.n_vars, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + g = adata.uns["rank_genes_groups"] + c = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "pvals", "pvals_adj"): + for grp in g[field].dtype.names: + np.testing.assert_allclose( + np.asarray(g[field][grp], dtype=float), + np.asarray(c[field][grp], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) -class TestRankingKernel: - """Tests for _average_ranks based on scipy.stats.rankdata edge cases.""" - - @pytest.fixture - def average_ranks(self): - """Import the ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - ) - return _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +@pytest.mark.parametrize("data_dtype", [np.int32, np.int64, np.uint16, bool]) +def test_wilcoxon_sparse_integer_bool_data_matches_float32(fmt, data_dtype): + # Integer/bool data hits the cast-to-float32 branch; must match float32. + rng = np.random.default_rng(5) + n_obs, n_genes = 100, 6 + counts = rng.integers(0, 5, size=(n_obs, n_genes)) + if data_dtype is bool: + counts = counts > 2 + typed = counts.astype(data_dtype) + f32 = counts.astype(np.float32) + labels = np.array([f"{i % 3}" for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{j}" for j in range(n_genes)]) + + def run(arr): + adata = sc.AnnData(X=_to_format(arr, fmt), obs=obs.copy(), var=var.copy()) + adata.uns["log1p"] = {"base": None} + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + tie_correct=True, + n_genes=n_genes, + ) + return adata.uns["rank_genes_groups"] - def test_basic_ranking(self, average_ranks): - """Test basic average ranking on simple data.""" - values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) + r_typed = run(typed) + r_f32 = run(f32) + for grp in r_typed["scores"].dtype.names: + np.testing.assert_array_equal( + np.asarray(r_typed["scores"][grp], dtype=float), + np.asarray(r_f32["scores"][grp], dtype=float), + ) - def test_all_ties(self, average_ranks): - """All identical values should get the average rank.""" - values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - def test_no_ties(self, average_ranks): - """All unique values should get sequential ranks.""" - values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) +def test_wilcoxon_device_sparse_bool_data_raises(): + counts = np.arange(400).reshape(100, 4) % 3 == 0 + mat = cpsp.csr_matrix(cp.asarray(counts)) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(100)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="float32 or float64"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) - def test_mixed_ties(self, average_ranks): - """Mix of ties and unique values.""" - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - def test_negative_values(self, average_ranks): - """Test with negative values.""" - values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_wilcoxon_sparse_float16_data_raises(fmt): + # Unsupported float16 sparse data (host + device) is rejected with TypeError. + rng = np.random.default_rng(0) + dense = np.abs(rng.standard_normal((40, 4))).astype(np.float32) + mat = _to_format(dense, fmt) + xp = cp if fmt.startswith("cupy") else np + mat.data = mat.data.astype(xp.float16) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(40)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="float32"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_wilcoxon_sparse_complex_data_raises(fmt): + rng = np.random.default_rng(4) + dense = np.abs(rng.standard_normal((40, 4))).astype(np.float32) + dense[dense < 0.4] = 0.0 + mat = _to_format(dense.astype(np.complex64), fmt) + adata = sc.AnnData( + X=mat, + obs=pd.DataFrame({"group": pd.Categorical([f"{i % 2}" for i in range(40)])}), + var=pd.DataFrame(index=[f"g{j}" for j in range(4)]), + ) + with pytest.raises(TypeError, match="complex sparse data is not supported"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) - def test_single_element(self, average_ranks): - """Single element should have rank 1.""" - values = [42.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.0]) - def test_two_elements_tied(self, average_ranks): - """Two tied elements should both have rank 1.5.""" - values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) +@pytest.mark.parametrize("reference", ["rest", "2"]) +def test_wilcoxon_group_subset_column_order_matches_scanpy(reference): + """Output column order must echo the user's ``groups=`` list (scanpy parity), + not be re-sorted to category order.""" + np.random.seed(0) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) + bdata = adata.copy() - def test_multiple_columns(self, average_ranks): - """Test ranking across multiple columns independently.""" - col0 = [3.0, 1.0, 2.0] - col1 = [1.0, 1.0, 2.0] - data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) + # Deliberately out-of-category-order subset. + groups = ["3", "1"] if reference != "rest" else ["3", "1", "0"] + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + groups=groups, + reference=reference, + ) + sc.tl.rank_genes_groups( + bdata, + "blobs", + method="wilcoxon", + use_raw=False, + groups=groups, + reference=reference, + ) + assert ( + adata.uns["rank_genes_groups"]["names"].dtype.names + == bdata.uns["rank_genes_groups"]["names"].dtype.names + ) - np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) - np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) +def test_wilcoxon_host_sparse_negative_chunked_stats_match_scanpy(): + """Host sparse negatives fallback must match scanpy stats across chunks.""" + rng = np.random.default_rng(0) + n_obs, n_vars = 200, 24 + X = (rng.random((n_obs, n_vars)) * 5.0).astype(np.float64) + X[X < 1.5] = 0.0 # structural zeros so pts < 1 + X[rng.random((n_obs, n_vars)) < 0.01] = -0.5 # a few negatives -> fallback + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_vars)]) -class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" + gpu = sc.AnnData(X=sp.csr_matrix(X), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=X.copy(), obs=obs.copy(), var=var.copy()) - @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - _tie_correction, + rsc.tl.rank_genes_groups( + gpu, + "group", + method="wilcoxon", + use_raw=False, + reference="rest", + pts=True, + n_genes=n_vars, + chunk_size=8, # < n_vars -> multiple chunks + ) + sc.tl.rank_genes_groups( + cpu, "group", method="wilcoxon", use_raw=False, reference="rest", pts=True + ) + g = gpu.uns["rank_genes_groups"] + c = cpu.uns["rank_genes_groups"] + for group in g["names"].dtype.names: + g_lfc = dict( + zip(g["names"][group], np.asarray(g["logfoldchanges"][group], float)) ) + c_lfc = dict( + zip(c["names"][group], np.asarray(c["logfoldchanges"][group], float)) + ) + for gene, val in g_lfc.items(): + np.testing.assert_allclose( + val, c_lfc[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + for frame in ("pts", "pts_rest"): + for col in c[frame].columns: + np.testing.assert_allclose( + g[frame].loc[c[frame].index, col].values, + c[frame][col].values, + rtol=1e-12, + atol=1e-13, + ) - return _tie_correction, _average_ranks - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_no_ties(self, tie_correction): - """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_all_ties(self, tie_correction): - """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_mixed_ties(self, tie_correction): - """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_two_elements_tied(self, tie_correction): - """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction - - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) +def test_wilcoxon_fdr_ties_nan_match_scanpy(): + """BH FDR must match scanpy on tied, constant, and all-zero genes.""" + rng = np.random.default_rng(1) + n_obs, n_vars = 240, 30 + X = rng.integers(0, 3, size=(n_obs, n_vars)).astype(np.float64) # heavy ties + X[:, 0] = 1.0 # constant gene -> identical p across groups + X[:, 1] = 0.0 # all-zero gene + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_vars)]) - def test_single_element(self, tie_correction): - """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction + gpu = sc.AnnData(X=cp.asarray(X), obs=obs.copy(), var=var.copy()) # GPU FDR path + cpu = sc.AnnData(X=X.copy(), obs=obs.copy(), var=var.copy()) - values = [42.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + rsc.tl.rank_genes_groups( + gpu, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + sc.tl.rank_genes_groups( + cpu, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + g = gpu.uns["rank_genes_groups"] + c = cpu.uns["rank_genes_groups"] + for group in g["names"].dtype.names: + g_adj = dict(zip(g["names"][group], np.asarray(g["pvals_adj"][group], float))) + c_adj = dict(zip(c["names"][group], np.asarray(c["pvals_adj"][group], float))) + for gene, val in g_adj.items(): + np.testing.assert_allclose( + val, c_adj[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) - # Single element: n^3 - n = 0, so formula gives 1.0 - np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - def test_multiple_columns(self, tie_correction): - """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction +def _promote_host_index_dtype(X): + """Copy a host scipy CSR/CSC matrix with promoted index-array dtypes.""" + X = X.copy() + X.indptr = X.indptr.astype(np.int64) + X.indices = X.indices.astype(np.int64) + return X - col0 = [1.0, 2.0, 3.0] # No ties - col1 = [5.0, 5.0, 5.0] # All ties - data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) - np.testing.assert_allclose( - result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR vs OVO host paths +@pytest.mark.parametrize( + ("layout", "data_dtype"), + [ + ("csr", np.float32), + ("csr", np.float64), + ("csc", np.float32), + ("csc", np.float64), + ], +) +def test_host_sparse_int64_templates_match_int32(reference, layout, data_dtype): + """Host sparse int64 index templates must match the int32 path bit-for-bit.""" + rng = np.random.default_rng(0) + dense = (rng.random((150, 8)) * 4.0).astype(np.float64) + dense[dense < 1.5] = 0.0 # nonnegative + structural zeros -> sparse fast path + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(8)]) + + maker = sp.csr_matrix if layout == "csr" else sp.csc_matrix + base = maker(dense.astype(data_dtype)) + + a32 = sc.AnnData(X=base.copy(), obs=obs.copy(), var=var.copy()) + a64 = sc.AnnData( + X=_promote_host_index_dtype(base), + obs=obs.copy(), + var=var.copy(), + ) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(a32, "group", **kw) + rsc.tl.rank_genes_groups(a64, "group", **kw) + + r32, r64 = a32.uns["rank_genes_groups"], a64.uns["rank_genes_groups"] + assert r64["names"].dtype.names == r32["names"].dtype.names + for fld in ("scores", "pvals", "pvals_adj", "logfoldchanges"): + for grp in r32[fld].dtype.names: + np.testing.assert_array_equal( + np.asarray(r64[fld][grp]), np.asarray(r32[fld][grp]) + ) + + +def _anndata_with_group_sizes(sizes, *, n_genes=6, seed=0): + """Dense AnnData with exact per-group sizes for OVO tier tests.""" + rng = np.random.default_rng(seed) + labels = [] + for name, n in sizes.items(): + labels += [name] * n + X = rng.integers(0, 6, size=(len(labels), n_genes)).astype(np.float64) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + return sc.AnnData(X=X, obs=obs, var=var) + + +def _assert_ovo_matches_scanpy(adata, reference): + bdata = adata.copy() + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(bdata, "group", **kw) + g, c = adata.uns["rank_genes_groups"], bdata.uns["rank_genes_groups"] + for fld in ("scores", "pvals", "pvals_adj"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +@pytest.mark.parametrize( + ("sizes", "seed"), + [ + ({"ref": 40, "g20": 20, "g50": 50, "g300": 300, "g1000": 1000}, 1), + ({"ref": 40, "huge": 3000}, 2), + ], +) +def test_ovo_tier_bands_match_scanpy(sizes, seed): + """OVO dense-tiered MEDIUM/LARGE/HUGE paths must match scanpy.""" + adata = _anndata_with_group_sizes(sizes, seed=seed) + _assert_ovo_matches_scanpy(adata, reference="ref") + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") # 6200 tiny groups warn +def test_ovr_dense_gmem_branch_matches_scipy(): + """Dense OVR gmem branch must match scipy on sampled groups.""" + from scipy.stats import mannwhitneyu + + n_groups, n_genes = 6200, 4 # > 6112 -> dense gmem accumulator + rng = np.random.default_rng(3) + labels = np.repeat(np.arange(n_groups), 2) # 2 cells per group + X = rng.integers(0, 6, size=(labels.size, n_genes)).astype(np.float64) + obs = pd.DataFrame({"group": pd.Categorical([str(x) for x in labels])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + adata = sc.AnnData(X=X, obs=obs, var=var) + + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, tie_correct=True + ) + res = adata.uns["rank_genes_groups"] + for grp in ("0", "1", "250", "1000", "3057", "6112", "6199"): + gp = dict(zip(res["names"][grp], np.asarray(res["pvals"][grp], float))) + mask = labels == int(grp) + for gi, gene in enumerate(var.index): + _, p = mannwhitneyu( + X[mask, gi], + X[~mask, gi], + use_continuity=False, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gp[gene], + p, + rtol=1e-10, + atol=1e-12, + equal_nan=True, + err_msg=f"group {grp} gene {gene}", + ) + + +def test_skip_empty_groups_vs_rest_drops_singleton(): + """skip_empty_groups=True with reference='rest' drops singleton groups.""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=4) + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, skip_empty_groups=True + ) + names = set(adata.uns["rank_genes_groups"]["names"].dtype.names) + assert names == {"a", "b"} # singleton "c" dropped, no error + + +def test_skip_empty_groups_reference_too_small_raises(): + """skip_empty_groups=True with a <2-cell reference raises a clear error.""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=4) + with pytest.raises(ValueError, match="reference = c has fewer than two samples"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference="c", + skip_empty_groups=True, + ) + + +def test_skip_empty_groups_none_remain_raises(): + """skip_empty_groups=True raises when no group has >=2 cells (vs-rest).""" + adata = _anndata_with_group_sizes({"a": 1, "b": 1, "c": 1}, seed=4) + with pytest.raises(ValueError, match="No groups with at least two samples remain"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, skip_empty_groups=True ) + + +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +def test_ovr_tie_correct_false_tie_heavy_matches_scanpy(fmt): + """OVR tie_correct=False on tie-heavy data must match scanpy for all formats.""" + rng = np.random.default_rng(7) + n_obs, n_genes = 180, 8 + dense = rng.integers(0, 5, size=(n_obs, n_genes)).astype(np.float64) # ties + dense[dense < 1.0] = 0.0 # nonnegative + structural zeros -> sparse fast path + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(n_obs)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": False, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +@pytest.mark.parametrize( + "fmt", ["numpy_dense", "scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"] +) +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR and OVO epilogues +def test_use_continuity_matches_scipy(fmt, reference): + """Continuity epilogues must match scipy across OVR/OVO and formats.""" + from scipy.stats import mannwhitneyu + + rng = np.random.default_rng(8) + n_obs, n_genes = 150, 6 + # Overlapping groups (same distribution) -> moderate |R-E[R]| -> continuity + # is material. Integer values give ties (exercises the tie term too). + dense = rng.integers(0, 4, size=(n_obs, n_genes)).astype(np.float64) + labels = np.array([str(i % 3) for i in range(n_obs)]) + obs = pd.DataFrame({"group": pd.Categorical(labels)}) + var = pd.DataFrame(index=[f"g{i}" for i in range(n_genes)]) + + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + rsc.tl.rank_genes_groups( + gpu, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + tie_correct=True, + use_continuity=True, + n_genes=n_genes, + ) + res = gpu.uns["rank_genes_groups"] + for grp in res["names"].dtype.names: + gm = dict(zip(res["names"][grp], np.asarray(res["pvals"][grp], float))) + mask_g = labels == grp + mask_r = (labels != grp) if reference == "rest" else (labels == reference) + for gi, gene in enumerate(var.index): + _, p = mannwhitneyu( + dense[mask_g, gi], + dense[mask_r, gi], + use_continuity=True, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gm[gene], p, rtol=1e-10, atol=1e-12, equal_nan=True + ) + + +# Entry-point / init validation (rank_genes_groups + _RankGenes + _select_groups). + + +def test_rank_genes_groups_default_method_is_ttest(): + """Omitting method= defaults to t-test (rank_genes_groups path).""" + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + rsc.tl.rank_genes_groups(adata, "group", use_raw=False) + assert adata.uns["rank_genes_groups"]["params"]["method"] == "t-test" + + +@pytest.mark.parametrize( + ("override", "exc", "match"), + [ + ({"method": "nope"}, ValueError, "method must be one of"), + ({"corr_method": "foo"}, ValueError, "corr_method must be either"), + ({"chunk_size": 0}, ValueError, "chunk_size must be a positive integer"), + ({"chunk_size": -4}, ValueError, "chunk_size must be a positive integer"), + ({"groups": "0"}, ValueError, "Specify a sequence of groups"), + ({"reference": "ZZ"}, ValueError, "needs to be one of groupby"), + ], +) +def test_rank_genes_groups_invalid_args_raise(override, exc, match): + """Public-API argument validation raises (covers __init__/_core guards).""" + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + kwargs = {"method": "wilcoxon", "use_raw": False, **override} + with pytest.raises(exc, match=match): + rsc.tl.rank_genes_groups(adata, "group", **kwargs) + + +def test_rank_genes_groups_mask_var_missing_key_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(KeyError, match="not found in adata.var"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", use_raw=False, mask_var="nope" + ) + + +def test_rank_genes_groups_mask_var_wrong_shape_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(ValueError, match="mask_var has wrong shape"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + mask_var=np.ones(adata.n_vars + 3, dtype=bool), + ) + + +def test_rank_genes_groups_layer_and_use_raw_conflict_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + adata.layers["L"] = adata.X.copy() + with pytest.raises(ValueError, match="Cannot specify .layer. and have"): + rsc.tl.rank_genes_groups( + adata, "group", method="wilcoxon", layer="L", use_raw=True + ) + + +def test_rank_genes_groups_use_raw_without_raw_raises(): + adata = _anndata_with_group_sizes({"0": 10, "1": 10}, seed=5) + with pytest.raises(ValueError, match="is empty"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=True) + + +def test_singleton_group_without_skip_raises(): + """Non-skip path: a <2-cell group raises in _select_groups (line 131-135).""" + adata = _anndata_with_group_sizes({"a": 10, "b": 10, "c": 1}, seed=5) + with pytest.raises(ValueError, match="fewer than two samples"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +@pytest.mark.parametrize("use_raw", [None, True]) +def test_rank_genes_groups_reads_raw_matches_scanpy(use_raw): + """use_raw=None and use_raw=True both read adata.raw, matching scanpy.""" + adata = _anndata_with_group_sizes({"0": 30, "1": 30, "2": 30}, seed=6) + adata.raw = adata.copy() # raw holds the real signal + rng = np.random.default_rng(99) + adata.X = rng.integers(0, 6, size=adata.shape).astype(np.float64) # noise in .X + bdata = adata.copy() + kw = {"method": "wilcoxon", "use_raw": use_raw, "tie_correct": True} + rsc.tl.rank_genes_groups(adata, "group", **kw) + sc.tl.rank_genes_groups(bdata, "group", **kw) + g, c = adata.uns["rank_genes_groups"], bdata.uns["rank_genes_groups"] + for grp in g["scores"].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g["scores"][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c["scores"][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) # OVR (_core) + OVO (_wilcoxon) +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr"]) +def test_log1p_base_logfoldchanges_match_scanpy(reference, fmt): + """A non-default log1p base changes expm1 in the logfoldchange computation + (_core.py:115 + the OVO host-sparse fast path _wilcoxon.py:232-234).""" + rng = np.random.default_rng(7) + dense = rng.integers(1, 6, size=(120, 6)).astype(np.float64) # nonneg, finite lfc + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(120)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + gpu.uns["log1p"] = {"base": 2.0} + cpu.uns["log1p"] = {"base": 2.0} + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for grp in g["logfoldchanges"].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g["logfoldchanges"][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c["logfoldchanges"][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-6, atol=1e-6, equal_nan=True + ) + + +# OVO / OVR parity and dispatch gaps. + + +def test_ovo_dense_fallback_pts_match_scanpy(): + """OVO sparse-negative dense fallback pts must match scanpy.""" + rng = np.random.default_rng(11) + dense = (rng.random((120, 8)) * 5.0).astype(np.float64) + dense[dense < 1.5] = 0.0 + dense[rng.random(dense.shape) < 0.01] = -0.5 # negatives -> dense fallback + obs = pd.DataFrame( + {"group": pd.Categorical(["a" if i % 2 else "b" for i in range(120)])} + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(8)]) + gpu = sc.AnnData(X=sp.csr_matrix(dense), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "b", + "pts": True, + "n_genes": 8, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for col in c["pts"].columns: np.testing.assert_allclose( - result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 + g["pts"].loc[c["pts"].index, col].values, + c["pts"][col].values, + rtol=1e-12, + atol=1e-13, ) - def test_large_tie_groups(self, tie_correction): - """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) - values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) +@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_csr"]) # CPU + GPU FDR epilogues +def test_bonferroni_matches_scanpy(fmt): + """Bonferroni correction must match scanpy, not just clamp below one.""" + rng = np.random.default_rng(12) + dense = rng.integers(0, 5, size=(150, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(150)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "1", + "corr_method": "bonferroni", + "tie_correct": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals", "pvals_adj"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-12, atol=1e-13, equal_nan=True + ) + + +def test_ovr_all_empty_csc_totals_runs(): + """All-zero host CSC + a groups= subset (leaves an unselected category) + + reference='rest' + pts=True exercises the empty-column totals branch.""" + dense = np.zeros((20, 5), dtype=np.float64) + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(20)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(5)]) + adata = sc.AnnData(X=sp.csc_matrix(dense), obs=obs, var=var) + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + groups=["0", "1"], + reference="rest", + pts=True, + ) + res = adata.uns["rank_genes_groups"] + for grp in res["scores"].dtype.names: + assert np.all(np.isfinite(np.asarray(res["scores"][grp], float))) + assert "pts_rest" in res + + +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_ovr_fully_dense_column_match_scanpy(fmt): + """A column with no structural zeros (nnz==n_rows) hits the total_zero==0 + branch of the sparse OVR accumulate kernel. Validate vs scanpy.""" + rng = np.random.default_rng(13) + dense = rng.integers(0, 5, size=(90, 4)).astype(np.float64) + dense[dense < 1.0] = 0.0 + dense[:, 0] = rng.integers(1, 6, size=90) # column 0 strictly positive -> no zeros + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(90)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(4)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-13, atol=1e-15, equal_nan=True + ) + - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) +@pytest.mark.parametrize("fmt", ["cupy_csr", "cupy_csc"]) +def test_ovr_device_sparse_subset_match_scanpy(fmt): + """Device-sparse OVR with a groups= subset exercises the sentinel-group skip + in the device sparse kernels. Validate vs scanpy on the dense copy.""" + rng = np.random.default_rng(14) + dense = rng.integers(0, 6, size=(160, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 4}" for i in range(160)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + gpu = sc.AnnData(X=_to_format(dense, fmt), obs=obs.copy(), var=var.copy()) + cpu = sc.AnnData(X=dense.copy(), obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(gpu, "group", **kw) + sc.tl.rank_genes_groups(cpu, "group", **kw) + g, c = gpu.uns["rank_genes_groups"], cpu.uns["rank_genes_groups"] + for fld in ("scores", "pvals"): + for grp in g[fld].dtype.names: + gm = dict(zip(g["names"][grp], np.asarray(g[fld][grp], float))) + cm = dict(zip(c["names"][grp], np.asarray(c[fld][grp], float))) + for gene, val in gm.items(): + np.testing.assert_allclose( + val, cm[gene], rtol=1e-13, atol=1e-15, equal_nan=True + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("layout", ["csr", "csc"]) +def test_host_sparse_mismatched_index_dtype_raises(reference, layout): + """Host sparse indices/indptr must keep scipy's same-dtype invariant.""" + rng = np.random.default_rng(15) + dense = rng.integers(0, 5, size=(120, 6)).astype(np.float64) + dense[dense < 1.0] = 0.0 + obs = pd.DataFrame({"group": pd.Categorical([f"{i % 3}" for i in range(120)])}) + var = pd.DataFrame(index=[f"g{i}" for i in range(6)]) + maker = sp.csr_matrix if layout == "csr" else sp.csc_matrix + m64 = maker(dense) + m64.indices = m64.indices.astype(np.int64) # keep indptr int32 + assert m64.indptr.dtype == np.int32 + assert m64.indices.dtype == np.int64 + adata = sc.AnnData(X=m64, obs=obs.copy(), var=var.copy()) + kw = { + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + with pytest.raises(TypeError, match="indices and indptr must have the same dtype"): + rsc.tl.rank_genes_groups(adata, "group", **kw) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index 85abc3e2..6922d7ce 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -428,21 +428,43 @@ def test_sparse_with_actual_zeros(self, adata_blobs): assert np.all(pvals >= 0) assert np.all(pvals <= 1) - def test_sparse_negative_values_raises(self, adata_blobs): - """Sparse input with negative values should raise ValueError.""" + def test_sparse_negative_values_fallback(self, adata_blobs): + """Sparse negatives must densify so implicit zeros rank correctly.""" import cupy as cp import cupyx.scipy.sparse as cpsp - adata = adata_blobs.copy() - rsc.get.anndata_to_GPU(adata) - # Make sparse with negative values - dense = cp.array(adata.X) - dense[:, 0] = -1.0 - adata.X = cpsp.csr_matrix(dense) + rng = np.random.default_rng(0) + n_obs, n_vars = adata_blobs.shape + base = (rng.random((n_obs, n_vars)) * 5.0).astype(np.float64) + base[base < 2.0] = 0.0 # real structural zeros (~40%) + # Negatives in zero-bearing cells: columns then hold structural zeros + # AND values below them, the case the fallback must rank correctly. + neg = (base == 0.0) & (rng.random(base.shape) < 0.05) + base[neg] = -0.5 + base[0, 1] = 10.0 # keep a positive max so sparse/dense ranges agree + assert (base == 0).any() and (base < 0).any() and (base > 0).any() + + dense = cp.asarray(base) + sparse_adata = adata_blobs.copy() + sparse_adata.X = cpsp.csr_matrix(dense) + dense_adata = adata_blobs.copy() + dense_adata.X = dense - with pytest.raises(ValueError, match="Sparse input contains negative values"): - rsc.tl.rank_genes_groups( - adata, "blobs", method="wilcoxon_binned", use_raw=False + rsc.tl.rank_genes_groups( + sparse_adata, "blobs", method="wilcoxon_binned", use_raw=False + ) + rsc.tl.rank_genes_groups( + dense_adata, "blobs", method="wilcoxon_binned", use_raw=False + ) + + sp_scores = sparse_adata.uns["rank_genes_groups"]["scores"] + dn_scores = dense_adata.uns["rank_genes_groups"]["scores"] + for group in sp_scores.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_scores[group], dtype=float), + np.asarray(dn_scores[group], dtype=float), + rtol=1e-13, + atol=1e-13, ) def test_log1p_warning(self, adata_blobs): @@ -534,3 +556,122 @@ def test_top_genes_match_scipy(adata_blobs): scipy_top = set(adata_blobs.var_names[np.argsort(pvals)[:n_top]]) overlap = len(binned_top & scipy_top) assert overlap >= n_top - 1, f"Group {group}: {overlap}/{n_top} overlap" + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_binned_bin_exact_matches_scipy(reference): + """Bin-exact integer data must match scipy.mannwhitneyu.""" + import pandas as pd + from scipy.stats import mannwhitneyu + + rng = np.random.default_rng(20) + n_obs, n_genes = 150, 6 + X = rng.integers(0, 5, size=(n_obs, n_genes)).astype(np.float32) # bin-exact + labels = np.array([str(i % 3) for i in range(n_obs)]) + genes = [f"v{i}" for i in range(n_genes)] + + def run(tie_correct, use_continuity): + a = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"g": pd.Categorical(labels)}), + var=pd.DataFrame(index=genes), + ) + a.uns["log1p"] = {"base": None} # silence the log-norm warning + rsc.get.anndata_to_GPU(a) + rsc.tl.rank_genes_groups( + a, + "g", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + tie_correct=tie_correct, + use_continuity=use_continuity, + n_bins=1000, + bin_range="auto", + n_genes=n_genes, + ) + r = a.uns["rank_genes_groups"] + return { + grp: dict(zip(r["names"][grp], np.asarray(r["pvals"][grp], float))) + for grp in r["names"].dtype.names + } + + # correctness vs scipy (tie_correct=True; both continuity settings) + for use_continuity in (False, True): + pv = run(True, use_continuity) + for grp, gm in pv.items(): + mask_g = labels == grp + mask_r = (labels != grp) if reference == "rest" else (labels == reference) + for gi, v in enumerate(genes): + _, p = mannwhitneyu( + X[mask_g, gi], + X[mask_r, gi], + use_continuity=use_continuity, + alternative="two-sided", + method="asymptotic", + ) + np.testing.assert_allclose( + gm[v], p, rtol=1e-6, atol=1e-6, equal_nan=True + ) + + # non-vacuity self-guards: each flag must materially change the result + def differs(a, b): + return any( + not np.isclose(a[g][v], b[g][v], rtol=1e-9, atol=1e-12) + for g in a + for v in a[g] + ) + + assert differs(run(True, True), run(True, False)), "use_continuity inert (vacuous)" + assert differs(run(True, False), run(False, False)), "tie_correct inert (vacuous)" + + +def test_binned_all_zero_sparse_finite(adata_blobs): + """All-zero in-memory sparse input (nnz==0 _data_range guard): no crash, all + pvals finite and 1.0 (every value in one bin -> z=0).""" + import cupy as cp + import cupyx.scipy.sparse as cpsp + + adata = adata_blobs.copy() + adata.X = cpsp.csr_matrix(cp.zeros(adata.shape, dtype=cp.float32)) + rsc.tl.rank_genes_groups(adata, "blobs", method="wilcoxon_binned", use_raw=False) + res = adata.uns["rank_genes_groups"] + for grp in res["pvals"].dtype.names: + p = np.asarray(res["pvals"][grp], dtype=float) + assert np.all(np.isfinite(p)) and np.allclose(p, 1.0) + + +def test_binned_log1p_invalid_for_negative_sparse_coerces_to_auto(adata_blobs): + """Negative sparse log1p range must warn, coerce to auto, and match auto.""" + import cupy as cp + import cupyx.scipy.sparse as cpsp + + rng = np.random.default_rng(21) + n_obs, n_vars = adata_blobs.shape + base = (rng.random((n_obs, n_vars)) * 4.0).astype(np.float64) + base[base < 1.5] = 0.0 + neg = (base == 0.0) & (rng.random(base.shape) < 0.05) + base[neg] = -0.5 + base[0, 1] = 10.0 # positive max so sparse/dense ranges align + dense = cp.asarray(base) + + sp_ad = adata_blobs.copy() + sp_ad.X = cpsp.csr_matrix(dense) + auto_ad = adata_blobs.copy() + auto_ad.X = cpsp.csr_matrix(dense) + with pytest.warns(RuntimeWarning, match="bin_range='log1p' is invalid"): + rsc.tl.rank_genes_groups( + sp_ad, "blobs", method="wilcoxon_binned", use_raw=False, bin_range="log1p" + ) + rsc.tl.rank_genes_groups( + auto_ad, "blobs", method="wilcoxon_binned", use_raw=False, bin_range="auto" + ) + sp_s = sp_ad.uns["rank_genes_groups"]["scores"] + au_s = auto_ad.uns["rank_genes_groups"]["scores"] + for grp in sp_s.dtype.names: + np.testing.assert_allclose( + np.asarray(sp_s[grp], dtype=float), + np.asarray(au_s[grp], dtype=float), + rtol=1e-13, + atol=1e-13, + )