Skip to content

Commit 178dffc

Browse files
authored
Drop cuml.thirdparty_adapters.check_array dependency (#660)
* replace cuml util * add release note
1 parent 2926263 commit 178dffc

4 files changed

Lines changed: 20 additions & 9 deletions

File tree

docs/release-notes/0.15.1.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
* ``adata.uns[key_added]["params"]["resolution"]`` is now stored as a scalar ``float`` when a single resolution
1717
is passed to ``tl.leiden`` and ``tl.louvain`` to match behaviour in Scanpy, and as a ``list`` when multiple
1818
resolutions are passed. Previously it was always stored as a list. {pr}`648`. {smaller}`J Pintar`
19+
* Drop dependency on ``cuml.thirdparty_adapters.check_array`` (removed in cuml 26.06); ``init_pos`` validation in ``tl.umap`` and ``tl.draw_graph`` is now handled locally {pr}`660` {smaller}`S Dicks`

src/rapids_singlecell/tools/_draw_graph.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import cudf
66
import cupy as cp
77
import numpy as np
8-
from cuml.thirdparty_adapters import check_array as check_array_cuml
98
from scanpy.tools._utils import get_init_pos_from_paga
109

1110
from rapids_singlecell._compat import _random_state_kwargs
1211

1312
from ._clustering import _create_graph
13+
from ._utils import _validate_init_pos
1414

1515
if TYPE_CHECKING:
1616
from anndata import AnnData
@@ -67,9 +67,7 @@ def draw_graph(
6767
case _:
6868
init_coords = init_pos
6969
if hasattr(init_coords, "dtype"):
70-
init_coords = check_array_cuml(
71-
init_coords, dtype=np.float32, accept_sparse=False
72-
)
70+
init_coords = _validate_init_pos(init_coords)
7371
if init_coords.shape[1] != 2:
7472
raise ValueError(
7573
f"Expected 2 columns but got {init_coords.shape[1]} columns."

src/rapids_singlecell/tools/_umap.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import cupy as cp
88
import numpy as np
99
from cuml.manifold.umap import UMAP, find_ab_params, simplicial_set_embedding
10-
from cuml.thirdparty_adapters import check_array as check_array_cuml
1110
from cupyx.scipy import sparse
1211
from packaging.version import parse as parse_version
1312
from scanpy._utils import NeighborsView
@@ -17,7 +16,7 @@
1716
from rapids_singlecell._compat import _random_state_kwargs
1817
from rapids_singlecell._utils import _get_logger_level
1918

20-
from ._utils import _choose_representation
19+
from ._utils import _choose_representation, _validate_init_pos
2120

2221
if TYPE_CHECKING:
2322
from anndata import AnnData
@@ -216,9 +215,7 @@ def umap(
216215
init_coords = init_pos
217216

218217
if hasattr(init_coords, "dtype"):
219-
init_coords = check_array_cuml(
220-
init_coords, dtype=np.float32, accept_sparse=False
221-
)
218+
init_coords = _validate_init_pos(init_coords)
222219

223220
random_state = check_random_state(random_state)
224221

src/rapids_singlecell/tools/_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
11
from __future__ import annotations
22

33
import cupy as cp
4+
import scipy.sparse as cpu_sparse
45
from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr
56

67
from rapids_singlecell._compat import DaskArray
78

89
from . import pca
910

1011

12+
def _validate_init_pos(init_coords):
13+
"""Coerce a user-supplied `init_pos` array to a 2D cupy float32 array.
14+
15+
Replaces the previous use of ``cuml.thirdparty_adapters.check_array``, which was
16+
removed in cuml 26.06.
17+
"""
18+
if cpu_sparse.issparse(init_coords) or issparse(init_coords):
19+
raise ValueError("Sparse `init_pos` is not supported.")
20+
arr = cp.asarray(init_coords, dtype=cp.float32)
21+
if arr.ndim != 2:
22+
raise ValueError(f"Expected 2D `init_pos`, got {arr.ndim}D array.")
23+
return arr
24+
25+
1126
def _choose_representation(adata, use_rep=None, n_pcs=None):
1227
if use_rep is None and n_pcs == 0: # backwards compat for specifying `.X`
1328
use_rep = "X"

0 commit comments

Comments
 (0)