Skip to content

Commit 97b767d

Browse files
committed
WNN: make sparse input actually work (closes #173)
1 parent f46ce8c commit 97b767d

1 file changed

Lines changed: 24 additions & 17 deletions

File tree

muon/_core/preproc.py

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from scipy.sparse import (
88
csr_matrix,
9+
csr_array,
910
issparse,
1011
SparseEfficiencyWarning,
1112
linalg,
@@ -21,8 +22,8 @@
2122
import scanpy
2223
from scanpy import logging
2324
from scanpy.tools._utils import _choose_representation
24-
from umap.distances import euclidean
25-
from umap.sparse import sparse_euclidean, sparse_jaccard
25+
from pynndescent.distances import euclidean
26+
from pynndescent.sparse import sparse_euclidean, sparse_jaccard
2627
from umap.umap_ import nearest_neighbors
2728
from numba import njit, prange
2829

@@ -41,9 +42,9 @@
4142

4243
# Computational methods for preprocessing
4344

44-
_euclidean = njit(euclidean.py_func, inline="always", fastmath=True)
45-
_sparse_euclidean = njit(sparse_euclidean.py_func, inline="always")
46-
_sparse_jaccard = njit(sparse_jaccard.py_func, inline="always")
45+
_euclidean = njit(getattr(euclidean, "py_func", euclidean), inline="always", fastmath=True)
46+
_sparse_euclidean = njit(getattr(sparse_euclidean, "py_func", sparse_euclidean), inline="always")
47+
_sparse_jaccard = njit(getattr(sparse_jaccard, "py_func", sparse_jaccard), inline="always")
4748

4849

4950
@njit
@@ -92,13 +93,17 @@ def _jaccard_sparse_euclidean_metric(
9293
if x == y:
9394
return N + 1.0
9495

95-
from_inds = X_indices[X_indptr[x] : X_indptr[x + 1]]
96-
from_data = X_data[X_indptr[x] : X_indptr[x + 1]]
97-
to_inds = X_indices[X_indptr[y] : X_indptr[y + 1]]
98-
to_data = X_data[X_indptr[y] : X_indptr[y + 1]]
96+
from_inds = neighbors_indices[neighbors_indptr[x] : neighbors_indptr[x + 1]]
97+
from_data = neighbors_data[neighbors_indptr[x] : neighbors_indptr[x + 1]]
98+
to_inds = neighbors_indices[neighbors_indptr[y] : neighbors_indptr[y + 1]]
99+
to_data = neighbors_data[neighbors_indptr[y] : neighbors_indptr[y + 1]]
99100
jac = _sparse_jaccard(from_inds, from_data, to_inds, to_data)
100101

101102
if jac < 1.0:
103+
from_inds = X_indices[X_indptr[x] : X_indptr[x + 1]]
104+
from_data = X_data[X_indptr[x] : X_indptr[x + 1]]
105+
to_inds = X_indices[X_indptr[y] : X_indptr[y + 1]]
106+
to_data = X_data[X_indptr[y] : X_indptr[y + 1]]
102107
euclidean = _sparse_euclidean(from_inds, from_data, to_inds, to_data)
103108
return (N - jac * N) + (bbox_norm - euclidean) / bbox_norm
104109
else:
@@ -130,12 +135,12 @@ def _sparse_csr_fast_knn_(
130135

131136

132137
# numba doesn't know about SciPy
133-
def _sparse_csr_fast_knn(X: csr_matrix, n_neighbors: int):
138+
def _sparse_csr_fast_knn(X: csr_matrix | csr_array, n_neighbors: int):
134139
data, indices, indptr = _sparse_csr_fast_knn_(
135140
X.shape[0], X.indptr, X.indices, X.data, n_neighbors
136141
)
137142
indptr = np.concatenate((indptr, (indices.size,)))
138-
return csr_matrix((data, indices, indptr), X.shape)
143+
return csr_array((data, indices, indptr), X.shape)
139144

140145

141146
@njit(parallel=True)
@@ -149,7 +154,7 @@ def _sparse_csr_ptp_(N: int, indptr: np.ndarray, indices: np.ndarray, data: np.n
149154
return maxelems - minelems
150155

151156

152-
def _sparse_csr_ptp(X: csr_matrix):
157+
def _sparse_csr_ptp(X: csr_matrix | csr_array):
153158
return _sparse_csr_ptp_(X.shape[1], X.indptr, X.indices, X.data)
154159

155160

@@ -179,7 +184,7 @@ def _l2norm(
179184
X_norm = linalg.norm(X, ord=2, axis=1)
180185
norm = X / np.expand_dims(X_norm, axis=1)
181186
if not issparse(norm):
182-
norm = csr_matrix(norm)
187+
norm = csr_array(norm)
183188
norm.data[~np.isfinite(norm.data)] = 0
184189
else:
185190
norm = X / np.linalg.norm(X, ord=2, axis=1, keepdims=True)
@@ -498,7 +503,7 @@ def neighbors(
498503
sigmas[mod1] = csigmas
499504

500505
weights = softmax(ratios, axis=1)
501-
neighbordistances = csr_matrix((mdata.n_obs, mdata.n_obs), dtype=np.float64)
506+
neighbordistances = csr_array((mdata.n_obs, mdata.n_obs), dtype=np.float64)
502507
largeidx = mdata.n_obs**2 > np.iinfo(np.int32).max
503508
if largeidx: # work around scipy bug https://github.com/scipy/scipy/issues/13155
504509
neighbordistances.indptr = neighbordistances.indptr.astype(np.int64)
@@ -520,7 +525,7 @@ def neighbors(
520525
angular=False,
521526
low_memory=lmemory,
522527
)
523-
graph = csr_matrix(
528+
graph = csr_array(
524529
(
525530
distances[:, 1:].reshape(-1),
526531
nn_indices[:, 1:].reshape(-1),
@@ -578,12 +583,14 @@ def neighbors(
578583
if issparse(rep):
579584

580585
def neighdist(cell, nz):
581-
return -cdist(rep[cell, :].toarray(), rep[nz, :].toarray(), metric=metric)
586+
return -cdist(
587+
rep[cell, :].toarray()[None, ...], rep[nz, :].toarray(), metric=metric
588+
)
582589

583590
else:
584591

585592
def neighdist(cell, nz):
586-
return -cdist(rep[np.newaxis, cell, :], rep[nz, :], metric=metric)
593+
return -cdist(rep[None, cell, :], rep[nz, :], metric=metric)
587594

588595
for cell, j in enumerate(fullidx):
589596
row = slice(neighbordistances.indptr[cell], neighbordistances.indptr[cell + 1])

0 commit comments

Comments
 (0)