Skip to content

Commit 781be59

Browse files
add init_pos to umap (#427)
* add test and implementation
* add igraph for testing
* pytest full
* update marks
* Update src/rapids_singlecell/tools/_umap.py

Co-authored-by: Philipp A. <flying-sheep@web.de>

* fix string
* adds release note
* test only minimal

---------

Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 209df0e commit 781be59

19 files changed

Lines changed: 145 additions & 25 deletions

.github/workflows/test-gpu-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949

5050
- name: Install rapids-singlecell
5151
run: >-
52-
pip install -e .[test]
52+
pip install -e .[test-minimal]
5353
"scanpy @ git+https://github.com/scverse/scanpy.git"
5454
"anndata @ git+https://github.com/scverse/anndata.git"
5555

docs/release-notes/0.13.1.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
```{rubric} Features
44
```
5+
* adds support in `tl.umap` for `init_pos` given as an `ndarray`, `paga` or `obsm[key]` {pr}`427` {smaller}`S Dicks`
56

67
```{rubric} Performance
78
```
@@ -14,3 +15,4 @@
1415

1516
```{rubric} Misc
1617
```
18+
* refactors `testing_utils` {pr}`427` {smaller}`S Dicks`

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,18 @@ doc = [
4040
"dask",
4141
"pytest",
4242
]
43-
test = [
43+
test-minimal = [
4444
"pytest",
4545
"profimp",
4646
"scanpy>=1.10.0",
4747
"bbknn",
4848
"decoupler",
4949
"fast-array-utils",
5050
]
51+
test = [
52+
"rapids_singlecell[test-minimal]",
53+
"igraph",
54+
]
5155

5256
[project.urls]
5357
Documentation = "https://rapids-singlecell.readthedocs.io"
@@ -88,8 +92,6 @@ lint.ignore = [
8892
"docs/*" = [ "I" ]
8993
"tests/*" = [ "D" ]
9094
"*/__init__.py" = [ "F401" ]
91-
"src/rapids_singlecell/decoupler_gpu/_method_mlm.py" = [ "PLR0917" ]
92-
"src/rapids_singlecell/decoupler_gpu/_method_wsum.py" = [ "PLR0917" ]
9395
[tool.ruff.lint.isort]
9496
known-first-party = [ "rapids_singlecell" ]
9597
required-imports = [ "from __future__ import annotations" ]
@@ -106,7 +108,6 @@ markers = [
106108
[tool.hatch.build]
107109
# exclude big files that don’t need to be installed
108110
exclude = [
109-
"src/rapids_singlecell/_testing.py",
110111
"tests",
111112
"docs",
112113
"notebooks",
@@ -118,7 +119,7 @@ version-file = "src/rapids_singlecell/_version.py"
118119
source = "vcs"
119120

120121
[tool.hatch.build.targets.wheel]
121-
packages = [ 'src/rapids_singlecell' ]
122+
packages = [ 'src/rapids_singlecell', 'src/testing' ]
122123

123124
[tool.codespell]
124125
skip = '*.ipynb,*.csv'

src/rapids_singlecell/tools/_umap.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
import cuml
66
import cuml.internals.logger as logger
77
import cupy as cp
8+
import numpy as np
89
from cuml.manifold.simpl_set import simplicial_set_embedding
910
from cuml.manifold.umap import UMAP
1011
from cuml.manifold.umap_utils import find_ab_params
12+
from cuml.thirdparty_adapters import check_array as check_array_cuml
1113
from cupyx.scipy import sparse
1214
from packaging.version import parse as parse_version
1315
from scanpy._utils import NeighborsView
16+
from scanpy.tools._utils import get_init_pos_from_paga
1417
from sklearn.utils import check_random_state
1518

1619
from rapids_singlecell._utils import _get_logger_level
@@ -20,7 +23,7 @@
2023
if TYPE_CHECKING:
2124
from anndata import AnnData
2225

23-
_InitPos = Literal["auto", "spectral", "random"]
26+
_InitPos = Literal["auto", "spectral", "random", "paga"]
2427

2528

2629
def umap(
@@ -32,7 +35,7 @@ def umap(
3235
maxiter: int | None = None,
3336
alpha: float = 1.0,
3437
negative_sample_rate: int = 5,
35-
init_pos: _InitPos = "auto",
38+
init_pos: _InitPos | np.ndarray | cp.ndarray | str | None = "auto",
3639
random_state: int = 0,
3740
a: float | None = None,
3841
b: float | None = None,
@@ -82,6 +85,9 @@ def umap(
8285
* 'auto': chooses 'spectral' for `'n_samples' < 1000000`, 'random' otherwise.
8386
* 'spectral': use a spectral embedding of the graph.
8487
* 'random': assign initial embedding positions at random.
88+
* 'paga': use the :func:`~scanpy.tl.paga` layout as initial embedding positions.
89+
* Array of shape (n_obs, 2)
90+
* Any key for :attr:`~anndata.AnnData.obsm`
8591
8692
.. note::
8793
If your embedding looks odd it's recommended setting `init_pos` to 'random'.
@@ -143,8 +149,6 @@ def umap(
143149
**({"random_state": random_state} if random_state != 0 else {}),
144150
}
145151

146-
random_state = check_random_state(random_state)
147-
148152
neigh_params = neighbors["params"]
149153
X = _choose_representation(
150154
adata,
@@ -167,6 +171,14 @@ def umap(
167171
else:
168172
pre_knn = None
169173

174+
if init_pos not in ["auto", "spectral", "random"]:
175+
raise ValueError(
176+
f"Invalid init_pos: {init_pos}",
177+
"Valid options are: auto, spectral, random, paga for RAPIDS < 24.10",
178+
)
179+
180+
random_state = check_random_state(random_state)
181+
170182
if init_pos == "auto":
171183
init_pos = "spectral" if n_obs < 1000000 else "random"
172184

@@ -192,8 +204,25 @@ def umap(
192204
else:
193205
pre_knn = neighbors["connectivities"]
194206

195-
if init_pos == "auto":
196-
init_pos = "spectral" if n_obs < 1000000 else "random"
207+
match init_pos:
208+
case str() if init_pos in adata.obsm:
209+
init_coords = adata.obsm[init_pos]
210+
case str() if init_pos == "paga":
211+
init_coords = get_init_pos_from_paga(
212+
adata, random_state=random_state, neighbors_key=neighbors_key
213+
)
214+
case str() if init_pos == "auto":
215+
init_coords = "spectral" if n_obs < 1000000 else "random"
216+
case _:
217+
init_coords = init_pos
218+
219+
if hasattr(init_coords, "dtype"):
220+
init_coords = check_array_cuml(
221+
init_coords, dtype=np.float32, accept_sparse=False
222+
)
223+
224+
random_state = check_random_state(random_state)
225+
197226
logger_level = _get_logger_level(logger)
198227
X_umap = simplicial_set_embedding(
199228
data=cp.array(X),
@@ -204,7 +233,7 @@ def umap(
204233
b=b,
205234
negative_sample_rate=negative_sample_rate,
206235
n_epochs=n_epochs,
207-
init=init_pos,
236+
init=init_coords,
208237
random_state=random_state,
209238
metric=neigh_params.get("metric", "euclidean"),
210239
metric_kwds=neigh_params.get("metric_kwds", None),
File renamed without changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from __future__ import annotations
2+
3+
from .marks import needs
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum, auto
4+
from importlib.util import find_spec
5+
from typing import TYPE_CHECKING
6+
7+
import pytest
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
12+
13+
SKIP_EXTRA: dict[str, Callable[[], str | None]] = {}
14+
15+
16+
class QuietMarkDecorator(pytest.MarkDecorator):
17+
def __init__(self, mark: pytest.Mark) -> None:
18+
super().__init__(mark, _ispytest=True)
19+
20+
21+
class needs(QuietMarkDecorator, Enum):
22+
"""Pytest skip marker evaluated at module import.
23+
24+
This allows us to see the number of skipped tests at the start of a test run.
25+
:func:`pytest.importorskip` skips tests after they started running.
26+
"""
27+
28+
# _generate_next_value_ needs to come before members
29+
@staticmethod
30+
def _generate_next_value_(
31+
name: str, start: int, count: int, last_values: list[str]
32+
) -> str:
33+
"""Distribution name for matching modules."""
34+
return name.replace("_", "-")
35+
36+
mod: str
37+
38+
igraph = auto()
39+
40+
def __init__(self, mod: str) -> None:
41+
self.mod = mod
42+
reason = self.skip_reason
43+
dec = pytest.mark.skipif(bool(reason), reason=reason or "")
44+
super().__init__(dec.mark)
45+
46+
@property
47+
def skip_reason(self) -> str | None:
48+
if find_spec(self._name_):
49+
if skip_extra := SKIP_EXTRA.get(self._name_):
50+
return skip_extra()
51+
return None
52+
reason = f"needs module `{self._name_}`"
53+
if self._name_.casefold() != self.mod.casefold().replace("-", "_"):
54+
reason = f"{reason} (`pip install {self.mod}`)"
55+
return reason

tests/dask/test_dask_aggr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scanpy.datasets import pbmc3k_processed
88

99
import rapids_singlecell as rsc
10-
from rapids_singlecell._testing import (
10+
from testing.rapids_singlecell._helper import (
1111
as_dense_cupy_dask_array,
1212
as_sparse_cupy_dask_array,
1313
)

tests/dask/test_dask_mean_var.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from scanpy.datasets import pbmc3k, pbmc68k_reduced
66

77
import rapids_singlecell as rsc
8-
from rapids_singlecell._testing import (
8+
from rapids_singlecell.preprocessing._utils import _get_mean_var
9+
from testing.rapids_singlecell._helper import (
910
as_dense_cupy_dask_array,
1011
as_sparse_cupy_dask_array,
1112
)
12-
from rapids_singlecell.preprocessing._utils import _get_mean_var
1313

1414
from ..test_score_genes import _create_sparse_nan_matrix # noqa: TID252
1515

tests/dask/test_dask_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from scipy import sparse
99

1010
import rapids_singlecell as rsc
11-
from rapids_singlecell._testing import (
11+
from testing.rapids_singlecell._helper import (
1212
as_dense_cupy_dask_array,
1313
as_sparse_cupy_dask_array,
1414
)

0 commit comments

Comments
 (0)