55import cuml
66import cuml .internals .logger as logger
77import cupy as cp
8+ import numpy as np
89from cuml .manifold .simpl_set import simplicial_set_embedding
910from cuml .manifold .umap import UMAP
1011from cuml .manifold .umap_utils import find_ab_params
12+ from cuml .thirdparty_adapters import check_array as check_array_cuml
1113from cupyx .scipy import sparse
1214from packaging .version import parse as parse_version
1315from scanpy ._utils import NeighborsView
16+ from scanpy .tools ._utils import get_init_pos_from_paga
1417from sklearn .utils import check_random_state
1518
1619from rapids_singlecell ._utils import _get_logger_level
2023if TYPE_CHECKING :
2124 from anndata import AnnData
2225
23- _InitPos = Literal ["auto" , "spectral" , "random" ]
26+ _InitPos = Literal ["auto" , "spectral" , "random" , "paga" ]
2427
2528
2629def umap (
@@ -32,7 +35,7 @@ def umap(
3235 maxiter : int | None = None ,
3336 alpha : float = 1.0 ,
3437 negative_sample_rate : int = 5 ,
35- init_pos : _InitPos = "auto" ,
38+ init_pos : _InitPos | np . ndarray | cp . ndarray | str | None = "auto" ,
3639 random_state : int = 0 ,
3740 a : float | None = None ,
3841 b : float | None = None ,
@@ -82,6 +85,9 @@ def umap(
8285 * 'auto': chooses 'spectral' for `'n_samples' < 1000000`, 'random' otherwise.
8386 * 'spectral': use a spectral embedding of the graph.
8487 * 'random': assign initial embedding positions at random.
88+ * 'paga': use the :func:`~scanpy.tl.paga` layout as initial embedding positions.
89+ * Array of shape (n_obs, 2)
90+ * Any key for :attr:`~anndata.AnnData.obsm`
8591
8692 .. note::
8793 If your embedding looks odd it's recommended setting `init_pos` to 'random'.
@@ -143,8 +149,6 @@ def umap(
143149 ** ({"random_state" : random_state } if random_state != 0 else {}),
144150 }
145151
146- random_state = check_random_state (random_state )
147-
148152 neigh_params = neighbors ["params" ]
149153 X = _choose_representation (
150154 adata ,
@@ -167,6 +171,14 @@ def umap(
167171 else :
168172 pre_knn = None
169173
174+ if init_pos not in ["auto" , "spectral" , "random" ]:
175+ raise ValueError (
176+ f"Invalid init_pos: { init_pos } " ,
177+ "Valid options are: auto, spectral, random, paga for RAPIDS < 24.10" ,
178+ )
179+
180+ random_state = check_random_state (random_state )
181+
170182 if init_pos == "auto" :
171183 init_pos = "spectral" if n_obs < 1000000 else "random"
172184
@@ -192,8 +204,25 @@ def umap(
192204 else :
193205 pre_knn = neighbors ["connectivities" ]
194206
195- if init_pos == "auto" :
196- init_pos = "spectral" if n_obs < 1000000 else "random"
207+ match init_pos :
208+ case str () if init_pos in adata .obsm :
209+ init_coords = adata .obsm [init_pos ]
210+ case str () if init_pos == "paga" :
211+ init_coords = get_init_pos_from_paga (
212+ adata , random_state = random_state , neighbors_key = neighbors_key
213+ )
214+ case str () if init_pos == "auto" :
215+ init_coords = "spectral" if n_obs < 1000000 else "random"
216+ case _:
217+ init_coords = init_pos
218+
219+ if hasattr (init_coords , "dtype" ):
220+ init_coords = check_array_cuml (
221+ init_coords , dtype = np .float32 , accept_sparse = False
222+ )
223+
224+ random_state = check_random_state (random_state )
225+
197226 logger_level = _get_logger_level (logger )
198227 X_umap = simplicial_set_embedding (
199228 data = cp .array (X ),
@@ -204,7 +233,7 @@ def umap(
204233 b = b ,
205234 negative_sample_rate = negative_sample_rate ,
206235 n_epochs = n_epochs ,
207- init = init_pos ,
236+ init = init_coords ,
208237 random_state = random_state ,
209238 metric = neigh_params .get ("metric" , "euclidean" ),
210239 metric_kwds = neigh_params .get ("metric_kwds" , None ),
0 commit comments