Skip to content

Commit 119007f

Browse files
authored
Use a kernel for embedding density (#590)
* first test * fix kernel to run like scipy * add release note
1 parent 533d4ea commit 119007f

7 files changed

Lines changed: 157 additions & 34 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ if (RSC_BUILD_EXTENSIONS)
8383
add_nb_cuda_module(_pv_cuda src/rapids_singlecell/_cuda/pv/pv.cu)
8484
add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu)
8585
add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu)
86+
add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu)
8687
add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu)
8788
# Harmony CUDA modules
8889
add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu)

docs/release-notes/0.15.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
```{rubric} Features
44
```
55
* Improves numerical accuracy and adds parameters to `tl.rank_genes_groups` Wilcoxon methods: uses ``erfc`` for p-values to avoid underflow, adds ``tie_correct`` and ``use_continuity`` to ``wilcoxon_binned``, and refactors ``Aggregate`` with a unified ``count_mean_var()`` dispatcher and raw ``sq_sum`` output for GPU-resident stats computation {pr}`585` {smaller}`S Dicks`
6+
* Replace cuML KDE in ``tl.embedding_density`` with a custom CUDA kernel using covariance-aware Gaussian KDE matching ``scipy.stats.gaussian_kde``, removing the cuML dependency and the ``batchsize`` parameter {pr}`590` {smaller}`S Dicks`
67

78
```{rubric} Removals
89
```

src/rapids_singlecell/_cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"_harmony_pen_cuda",
3131
"_harmony_scatter_cuda",
3232
"_hvg_cuda",
33+
"_kde_cuda",
3334
"_ligrec_cuda",
3435
"_mean_var_cuda",
3536
"_nanmean_cuda",
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "kernels_kde.cuh"
2+
#include "../nb_types.h"
3+
4+
using namespace nb::literals;
5+
6+
template <typename T>
7+
inline void launch_gaussian_kde_2d(const T* xy, T* out, int n, T a, T b, T c,
8+
cudaStream_t stream) {
9+
constexpr int threads = 256;
10+
const int blocks = (n + threads - 1) / threads;
11+
gaussian_kde_2d_kernel<<<blocks, threads, 0, stream>>>(xy, out, n, a, b, c);
12+
}
13+
14+
NB_MODULE(_kde_cuda, m) {
15+
m.def(
16+
"gaussian_kde_2d",
17+
[](cuda_array_c<const float> xy, cuda_array_c<float> out, int n,
18+
float a, float b, float c, std::uintptr_t stream) {
19+
launch_gaussian_kde_2d(xy.data(), out.data(), n, a, b, c,
20+
(cudaStream_t)stream);
21+
},
22+
"xy"_a, nb::kw_only(), "out"_a, "n"_a, "a"_a, "b"_a, "c"_a,
23+
"stream"_a = 0);
24+
25+
m.def(
26+
"gaussian_kde_2d",
27+
[](cuda_array_c<const double> xy, cuda_array_c<double> out, int n,
28+
double a, double b, double c, std::uintptr_t stream) {
29+
launch_gaussian_kde_2d(xy.data(), out.data(), n, a, b, c,
30+
(cudaStream_t)stream);
31+
},
32+
"xy"_a, nb::kw_only(), "out"_a, "n"_a, "a"_a, "b"_a, "c"_a,
33+
"stream"_a = 0);
34+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
#include <cuda_runtime.h>
4+
#include <math_constants.h>
5+
6+
template <typename T>
7+
__device__ __forceinline__ T neg_infinity();
8+
9+
template <>
10+
__device__ __forceinline__ float neg_infinity<float>() {
11+
return -CUDART_INF_F;
12+
}
13+
14+
template <>
15+
__device__ __forceinline__ double neg_infinity<double>() {
16+
return -CUDART_INF;
17+
}
18+
19+
template <typename T>
20+
__global__ void gaussian_kde_2d_kernel(const T* __restrict__ xy,
21+
T* __restrict__ out, const int n,
22+
const T a, const T b, const T c) {
23+
const int i = blockIdx.x * blockDim.x + threadIdx.x;
24+
if (i >= n) return;
25+
26+
const T xi = xy[2 * i];
27+
const T yi = xy[2 * i + 1];
28+
29+
T running_max = neg_infinity<T>();
30+
T running_sum = T(0);
31+
32+
for (int j = 0; j < n; j++) {
33+
const T dx = xi - xy[2 * j];
34+
const T dy = yi - xy[2 * j + 1];
35+
const T log_k = a * dx * dx + b * dx * dy + c * dy * dy;
36+
37+
if (log_k > running_max) {
38+
running_sum = running_sum * exp(running_max - log_k) + T(1);
39+
running_max = log_k;
40+
} else {
41+
running_sum += exp(log_k - running_max);
42+
}
43+
}
44+
45+
out[i] = log(running_sum) + running_max;
46+
}

src/rapids_singlecell/tools/_embedding_density.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import math
43
from typing import TYPE_CHECKING
54

65
import cupy as cp
@@ -20,7 +19,6 @@ def embedding_density(
2019
*,
2120
groupby: str | None = None,
2221
key_added: str | None = None,
23-
batchsize: int = 10000,
2422
components: str | Sequence[str] = None,
2523
) -> None:
2624
"""\
@@ -34,10 +32,6 @@ def embedding_density(
3432
the same category.
3533
This function was written by Sophie Tritschler and implemented into
3634
Scanpy by Malte Luecken.
37-
This function uses cuML's KernelDensity. It returns log Likelihood as does
38-
sklearn's implementation. scipy.stats implementation, used
39-
in scanpy, returns PDF.
40-
4135
Parameters
4236
----------
4337
adata
@@ -51,8 +45,6 @@ def embedding_density(
5145
key_added
5246
Name of the `.obs` covariate that will be added with the density
5347
estimates.
54-
batchsize
55-
Number of cells that should be processed together.
5648
components
5749
The embedding dimensions over which the density should be calculated.
5850
This is limited to two components.
@@ -76,7 +68,7 @@ def embedding_density(
7668
if basis == "fa":
7769
basis = "draw_graph_fa"
7870

79-
if f"X_{basis}" not in adata.obsm_keys():
71+
if f"X_{basis}" not in adata.obsm:
8072
raise ValueError(
8173
"Cannot find the embedded representation "
8274
f"`adata.obsm['X_{basis}']`. Compute the embedding first."
@@ -117,16 +109,16 @@ def embedding_density(
117109
embed_x = adata.obsm[f"X_{basis}"][cat_mask, components[0]]
118110
embed_y = adata.obsm[f"X_{basis}"][cat_mask, components[1]]
119111

120-
dens_embed = _calc_density(cp.array(embed_x), cp.array(embed_y), batchsize)
112+
dens_embed = _calc_density(cp.array(embed_x), cp.array(embed_y))
121113
density_values[cat_mask] = dens_embed
122114

123115
adata.obs[density_covariate] = density_values
124116
else: # if groupby is None
125117
# Calculate the density over the whole embedding without subsetting
126-
embed_x = adata.obsm[f"X_{basis}"][:, components[0]]
127-
embed_y = adata.obsm[f"X_{basis}"][:, components[1]]
118+
embed_x = cp.asarray(adata.obsm[f"X_{basis}"][:, components[0]])
119+
embed_y = cp.asarray(adata.obsm[f"X_{basis}"][:, components[1]])
128120

129-
adata.obs[density_covariate] = _calc_density(embed_x, embed_y, batchsize)
121+
adata.obs[density_covariate] = _calc_density(embed_x, embed_y)
130122

131123
# Reduce diffmap components for labeling
132124
# Note: plot_scatter takes care of correcting diffmap components
@@ -140,26 +132,47 @@ def embedding_density(
140132
}
141133

142134

143-
def _calc_density(x: cp.ndarray, y: cp.ndarray, batchsize: int):
135+
def _calc_density(x: cp.ndarray, y: cp.ndarray) -> np.ndarray:
144136
"""\
145-
Calculates the density of points in 2 dimensions.
137+
Calculates the density of points in 2 dimensions using a Gaussian KDE kernel.
138+
139+
Uses a covariance-aware bandwidth (Scott's rule) matching
140+
:class:`scipy.stats.gaussian_kde`, and min-max scales the PDF.
141+
Each GPU thread computes the log-density for one query point via an
142+
in-thread streaming logsumexp over all training points. No intermediate
143+
distance matrix is ever materialised.
146144
"""
147-
from cuml.neighbors import KernelDensity
148-
149-
# Calculate the point density
150-
xy = np.vstack([x, y]).T
151-
bandwidth = np.power(xy.shape[0], (-1.0 / (xy.shape[1] + 4)))
152-
kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(xy)
153-
z = cp.zeros(xy.shape[0])
154-
n_batches = math.ceil(xy.shape[0] / batchsize)
155-
for batch in range(n_batches):
156-
start_idx = batch * batchsize
157-
stop_idx = min(batch * batchsize + batchsize, xy.shape[0])
158-
z[start_idx:stop_idx] = cp.array(kde.score_samples(xy[start_idx:stop_idx, :]))
159-
min_z = cp.min(z)
160-
max_z = cp.max(z)
161-
162-
# Scale between 0 and 1
163-
scaled_z = (z - min_z) / (max_z - min_z)
164-
165-
return scaled_z.get()
145+
from rapids_singlecell._cuda import _kde_cuda
146+
147+
xy = cp.stack([x, y], axis=1) # (n, 2), C-contiguous
148+
n = xy.shape[0]
149+
dtype = xy.dtype
150+
151+
# Covariance-aware bandwidth matching scipy.stats.gaussian_kde
152+
scotts_factor = n ** (-1.0 / 6.0)
153+
data_cov = cp.cov(xy.T) # (2, 2)
154+
inv_cov = cp.linalg.inv(scotts_factor**2 * data_cov)
155+
156+
# Pre-multiply so the kernel just computes a·dx² + b·dx·dy + c·dy²
157+
a = -0.5 * float(inv_cov[0, 0])
158+
b = -float(inv_cov[0, 1])
159+
c = -0.5 * float(inv_cov[1, 1])
160+
161+
z = cp.empty(n, dtype=dtype)
162+
163+
_kde_cuda.gaussian_kde_2d(
164+
xy,
165+
out=z,
166+
n=n,
167+
a=a,
168+
b=b,
169+
c=c,
170+
stream=cp.cuda.get_current_stream().ptr,
171+
)
172+
173+
# Min-max scale PDF (not log-PDF) to match scipy/scanpy
174+
pdf = cp.exp(z)
175+
min_pdf = pdf.min()
176+
scaled = (pdf - min_pdf) / (pdf.max() - min_pdf)
177+
178+
return scaled.get()

tests/test_embedding_density.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pandas as pd
55
import pytest
6+
import scanpy as sc
67
from anndata import AnnData
78

89
import rapids_singlecell as rsc
@@ -153,3 +154,29 @@ def test_fa_alias():
153154

154155
rsc.tl.embedding_density(adata, "fa")
155156
assert "draw_graph_fa_density" in adata.obs.columns
157+
158+
159+
@pytest.fixture
160+
def pbmc68k():
161+
return sc.datasets.pbmc68k_reduced()
162+
163+
164+
@pytest.mark.parametrize("groupby", [None, "louvain"])
165+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
166+
def test_matches_scanpy(pbmc68k, groupby, dtype):
167+
"""GPU density matches scanpy on pbmc68k_reduced."""
168+
adata_sc = pbmc68k.copy()
169+
adata_sc.obsm["X_umap"] = adata_sc.obsm["X_umap"].astype(dtype)
170+
sc.tl.embedding_density(adata_sc, "umap", groupby=groupby)
171+
172+
adata_gpu = pbmc68k.copy()
173+
adata_gpu.obsm["X_umap"] = adata_gpu.obsm["X_umap"].astype(dtype)
174+
rsc.tl.embedding_density(adata_gpu, "umap", groupby=groupby)
175+
176+
key = "umap_density" if groupby is None else f"umap_density_{groupby}"
177+
atol = 1e-6 if dtype == np.float32 else 1e-12
178+
np.testing.assert_allclose(
179+
adata_gpu.obs[key].values,
180+
adata_sc.obs[key].values,
181+
atol=atol,
182+
)

0 commit comments

Comments
 (0)