66import numpy as np
77from scipy .sparse import (
88 csr_matrix ,
9+ csr_array ,
910 issparse ,
1011 SparseEfficiencyWarning ,
1112 linalg ,
2122import scanpy
2223from scanpy import logging
2324from scanpy .tools ._utils import _choose_representation
24- from umap .distances import euclidean
25- from umap .sparse import sparse_euclidean , sparse_jaccard
25+ from pynndescent .distances import euclidean
26+ from pynndescent .sparse import sparse_euclidean , sparse_jaccard
2627from umap .umap_ import nearest_neighbors
2728from numba import njit , prange
2829
4142
4243# Computational methods for preprocessing
4344
# Re-compile the imported distance kernels with ``inline="always"`` so numba
# inlines them into the jitted callers defined in this module.
# ``py_func`` is the plain Python function behind a numba dispatcher; passing
# the dispatcher itself to ``njit`` is rejected as "already jitted". When the
# imported object has no ``py_func`` attribute it is used as-is.
_euclidean = njit(getattr(euclidean, "py_func", euclidean), inline="always", fastmath=True)
_sparse_euclidean = njit(getattr(sparse_euclidean, "py_func", sparse_euclidean), inline="always")
_sparse_jaccard = njit(getattr(sparse_jaccard, "py_func", sparse_jaccard), inline="always")
4748
4849
4950@njit
@@ -92,13 +93,17 @@ def _jaccard_sparse_euclidean_metric(
9293 if x == y :
9394 return N + 1.0
9495
95- from_inds = X_indices [ X_indptr [x ] : X_indptr [x + 1 ]]
96- from_data = X_data [ X_indptr [x ] : X_indptr [x + 1 ]]
97- to_inds = X_indices [ X_indptr [y ] : X_indptr [y + 1 ]]
98- to_data = X_data [ X_indptr [y ] : X_indptr [y + 1 ]]
96+ from_inds = neighbors_indices [ neighbors_indptr [x ] : neighbors_indptr [x + 1 ]]
97+ from_data = neighbors_data [ neighbors_indptr [x ] : neighbors_indptr [x + 1 ]]
98+ to_inds = neighbors_indices [ neighbors_indptr [y ] : neighbors_indptr [y + 1 ]]
99+ to_data = neighbors_data [ neighbors_indptr [y ] : neighbors_indptr [y + 1 ]]
99100 jac = _sparse_jaccard (from_inds , from_data , to_inds , to_data )
100101
101102 if jac < 1.0 :
103+ from_inds = X_indices [X_indptr [x ] : X_indptr [x + 1 ]]
104+ from_data = X_data [X_indptr [x ] : X_indptr [x + 1 ]]
105+ to_inds = X_indices [X_indptr [y ] : X_indptr [y + 1 ]]
106+ to_data = X_data [X_indptr [y ] : X_indptr [y + 1 ]]
102107 euclidean = _sparse_euclidean (from_inds , from_data , to_inds , to_data )
103108 return (N - jac * N ) + (bbox_norm - euclidean ) / bbox_norm
104109 else :
@@ -130,12 +135,12 @@ def _sparse_csr_fast_knn_(
130135
131136
132137# numba doesn't know about SciPy
def _sparse_csr_fast_knn(X: csr_matrix | csr_array, n_neighbors: int):
    """Run the jitted ``_sparse_csr_fast_knn_`` kernel on a CSR container.

    The kernel only receives plain arrays (row count, ``indptr``, ``indices``,
    ``data``) because numba cannot handle SciPy sparse objects directly.

    Parameters
    ----------
    X
        CSR matrix or array whose raw buffers are handed to the kernel.
    n_neighbors
        Forwarded unchanged to the kernel.

    Returns
    -------
    A ``csr_array`` with the same shape as ``X`` built from the kernel output.
    """
    n_rows = X.shape[0]
    knn_data, knn_indices, row_offsets = _sparse_csr_fast_knn_(
        n_rows, X.indptr, X.indices, X.data, n_neighbors
    )
    # Append the total number of stored entries as the final offset before
    # handing the arrays to the CSR constructor.
    full_indptr = np.concatenate((row_offsets, (knn_indices.size,)))
    return csr_array((knn_data, knn_indices, full_indptr), X.shape)
139144
140145
141146@njit (parallel = True )
@@ -149,7 +154,7 @@ def _sparse_csr_ptp_(N: int, indptr: np.ndarray, indices: np.ndarray, data: np.n
149154 return maxelems - minelems
150155
151156
def _sparse_csr_ptp(X: csr_matrix | csr_array):
    """Run the jitted ``_sparse_csr_ptp_`` kernel on a CSR container.

    The kernel is given the number of columns of ``X`` together with its raw
    ``indptr``, ``indices`` and ``data`` buffers, and its result is returned
    unchanged.
    """
    n_cols = X.shape[1]
    return _sparse_csr_ptp_(n_cols, X.indptr, X.indices, X.data)
154159
155160
@@ -179,7 +184,7 @@ def _l2norm(
179184 X_norm = linalg .norm (X , ord = 2 , axis = 1 )
180185 norm = X / np .expand_dims (X_norm , axis = 1 )
181186 if not issparse (norm ):
182- norm = csr_matrix (norm )
187+ norm = csr_array (norm )
183188 norm .data [~ np .isfinite (norm .data )] = 0
184189 else :
185190 norm = X / np .linalg .norm (X , ord = 2 , axis = 1 , keepdims = True )
@@ -498,7 +503,7 @@ def neighbors(
498503 sigmas [mod1 ] = csigmas
499504
500505 weights = softmax (ratios , axis = 1 )
501- neighbordistances = csr_matrix ((mdata .n_obs , mdata .n_obs ), dtype = np .float64 )
506+ neighbordistances = csr_array ((mdata .n_obs , mdata .n_obs ), dtype = np .float64 )
502507 largeidx = mdata .n_obs ** 2 > np .iinfo (np .int32 ).max
503508 if largeidx : # work around scipy bug https://github.com/scipy/scipy/issues/13155
504509 neighbordistances .indptr = neighbordistances .indptr .astype (np .int64 )
@@ -520,7 +525,7 @@ def neighbors(
520525 angular = False ,
521526 low_memory = lmemory ,
522527 )
523- graph = csr_matrix (
528+ graph = csr_array (
524529 (
525530 distances [:, 1 :].reshape (- 1 ),
526531 nn_indices [:, 1 :].reshape (- 1 ),
@@ -578,12 +583,14 @@ def neighbors(
578583 if issparse (rep ):
579584
def neighdist(cell, nz):
    """Return negated distances from row ``cell`` to rows ``nz`` of sparse ``rep``.

    ``rep`` and ``metric`` come from the enclosing scope. The result has shape
    ``(1, len(nz))``.
    """
    # A one-row slice stays 2-D for both csr_matrix and csr_array, whereas
    # integer indexing is 2-D for matrices but can be 1-D for sparse arrays.
    # cdist requires 2-D inputs on both sides.
    query = rep[cell : cell + 1, :].toarray()
    return -cdist(query, rep[nz, :].toarray(), metric=metric)
582589
583590 else :
584591
def neighdist(cell, nz):
    """Return negated distances from row ``cell`` to rows ``nz`` of dense ``rep``.

    ``rep`` and ``metric`` come from the enclosing scope.
    """
    # The leading ``None`` axis keeps the single query row 2-D, as cdist
    # requires (assumes ``rep`` itself is a 2-D array).
    query = rep[None, cell, :]
    return -cdist(query, rep[nz, :], metric=metric)
587594
588595 for cell , j in enumerate (fullidx ):
589596 row = slice (neighbordistances .indptr [cell ], neighbordistances .indptr [cell + 1 ])
0 commit comments