11from __future__ import annotations
22
3- import math
43from typing import TYPE_CHECKING
54
65import cupy as cp
@@ -20,7 +19,6 @@ def embedding_density(
2019 * ,
2120 groupby : str | None = None ,
2221 key_added : str | None = None ,
23- batchsize : int = 10000 ,
2422 components : str | Sequence [str ] = None ,
2523) -> None :
2624 """\
@@ -34,10 +32,6 @@ def embedding_density(
3432 the same category.
3533 This function was written by Sophie Tritschler and implemented into
3634 Scanpy by Malte Luecken.
37- This function uses cuML's KernelDensity. It returns log Likelihood as does
38- sklearn's implementation. scipy.stats implementation, used
39- in scanpy, returns PDF.
40-
4135 Parameters
4236 ----------
4337 adata
@@ -51,8 +45,6 @@ def embedding_density(
5145 key_added
5246 Name of the `.obs` covariate that will be added with the density
5347 estimates.
54- batchsize
55- Number of cells that should be processed together.
5648 components
5749 The embedding dimensions over which the density should be calculated.
5850 This is limited to two components.
@@ -76,7 +68,7 @@ def embedding_density(
7668 if basis == "fa" :
7769 basis = "draw_graph_fa"
7870
79- if f"X_{ basis } " not in adata .obsm_keys () :
71+ if f"X_{ basis } " not in adata .obsm :
8072 raise ValueError (
8173 "Cannot find the embedded representation "
8274 f"`adata.obsm['X_{ basis } ']`. Compute the embedding first."
@@ -117,16 +109,16 @@ def embedding_density(
117109 embed_x = adata .obsm [f"X_{ basis } " ][cat_mask , components [0 ]]
118110 embed_y = adata .obsm [f"X_{ basis } " ][cat_mask , components [1 ]]
119111
120- dens_embed = _calc_density (cp .array (embed_x ), cp .array (embed_y ), batchsize )
112+ dens_embed = _calc_density (cp .array (embed_x ), cp .array (embed_y ))
121113 density_values [cat_mask ] = dens_embed
122114
123115 adata .obs [density_covariate ] = density_values
124116 else : # if groupby is None
125117 # Calculate the density over the whole embedding without subsetting
126- embed_x = adata .obsm [f"X_{ basis } " ][:, components [0 ]]
127- embed_y = adata .obsm [f"X_{ basis } " ][:, components [1 ]]
118+ embed_x = cp . asarray ( adata .obsm [f"X_{ basis } " ][:, components [0 ]])
119+ embed_y = cp . asarray ( adata .obsm [f"X_{ basis } " ][:, components [1 ]])
128120
129- adata .obs [density_covariate ] = _calc_density (embed_x , embed_y , batchsize )
121+ adata .obs [density_covariate ] = _calc_density (embed_x , embed_y )
130122
131123 # Reduce diffmap components for labeling
132124 # Note: plot_scatter takes care of correcting diffmap components
@@ -140,26 +132,47 @@ def embedding_density(
140132 }
141133
142134
143- def _calc_density (x : cp .ndarray , y : cp .ndarray , batchsize : int ) :
135+ def _calc_density (x : cp .ndarray , y : cp .ndarray ) -> np . ndarray :
144136 """\
145- Calculates the density of points in 2 dimensions.
137+ Calculates the density of points in 2 dimensions using a Gaussian KDE kernel.
138+
139+ Uses a covariance-aware bandwidth (Scott's rule) matching
140+ :class:`scipy.stats.gaussian_kde`, and min-max scales the PDF.
141+ Each GPU thread computes the log-density for one query point via an
142+ in-thread streaming logsumexp over all training points. No intermediate
143+ distance matrix is ever materialised.
146144 """
147- from cuml .neighbors import KernelDensity
148-
149- # Calculate the point density
150- xy = np .vstack ([x , y ]).T
151- bandwidth = np .power (xy .shape [0 ], (- 1.0 / (xy .shape [1 ] + 4 )))
152- kde = KernelDensity (kernel = "gaussian" , bandwidth = bandwidth ).fit (xy )
153- z = cp .zeros (xy .shape [0 ])
154- n_batches = math .ceil (xy .shape [0 ] / batchsize )
155- for batch in range (n_batches ):
156- start_idx = batch * batchsize
157- stop_idx = min (batch * batchsize + batchsize , xy .shape [0 ])
158- z [start_idx :stop_idx ] = cp .array (kde .score_samples (xy [start_idx :stop_idx , :]))
159- min_z = cp .min (z )
160- max_z = cp .max (z )
161-
162- # Scale between 0 and 1
163- scaled_z = (z - min_z ) / (max_z - min_z )
164-
165- return scaled_z .get ()
145+ from rapids_singlecell ._cuda import _kde_cuda
146+
147+ xy = cp .stack ([x , y ], axis = 1 ) # (n, 2), C-contiguous
148+ n = xy .shape [0 ]
149+ dtype = xy .dtype
150+
151+ # Covariance-aware bandwidth matching scipy.stats.gaussian_kde
152+ scotts_factor = n ** (- 1.0 / 6.0 )
153+ data_cov = cp .cov (xy .T ) # (2, 2)
154+ inv_cov = cp .linalg .inv (scotts_factor ** 2 * data_cov )
155+
156+ # Pre-multiply so the kernel just computes a·dx² + b·dx·dy + c·dy²
157+ a = - 0.5 * float (inv_cov [0 , 0 ])
158+ b = - float (inv_cov [0 , 1 ])
159+ c = - 0.5 * float (inv_cov [1 , 1 ])
160+
161+ z = cp .empty (n , dtype = dtype )
162+
163+ _kde_cuda .gaussian_kde_2d (
164+ xy ,
165+ out = z ,
166+ n = n ,
167+ a = a ,
168+ b = b ,
169+ c = c ,
170+ stream = cp .cuda .get_current_stream ().ptr ,
171+ )
172+
173+ # Min-max scale PDF (not log-PDF) to match scipy/scanpy
174+ pdf = cp .exp (z )
175+ min_pdf = pdf .min ()
176+ scaled = (pdf - min_pdf ) / (pdf .max () - min_pdf )
177+
178+ return scaled .get ()
0 commit comments