Skip to content

Commit 9de11b1

Browse files
flying-sheep, ilan-gold, and pre-commit-ci[bot]
authored
feat: support np.random.Generator (#3983)
Co-authored-by: Ilan Gold <ilanbassgold@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8c60cd2 commit 9de11b1

42 files changed

Lines changed: 937 additions & 615 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

benchmarks/benchmarks/preprocessing_counts.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from inspect import signature
89
from itertools import product
910
from typing import TYPE_CHECKING
1011

@@ -15,10 +16,19 @@
1516
from ._utils import get_count_dataset
1617

1718
if TYPE_CHECKING:
19+
from typing import Any
20+
1821
from ._utils import Dataset, KeyCount
1922

2023

21-
# ASV suite
24+
def cache_adata(dataset: Dataset, layer: KeyCount) -> None:
25+
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
26+
adata, batch_key = get_count_dataset(dataset, layer=layer)
27+
assert "lop1p" not in adata.uns
28+
adata.uns["batch_key"] = batch_key
29+
adata.write_h5ad(f"{dataset}_{layer}.h5ad")
30+
31+
2232
class PreprocessingCountsSuite: # noqa: D101
2333
params: tuple[list[Dataset], list[KeyCount]] = (
2434
["pbmc68k_reduced", "pbmc3k"],
@@ -27,12 +37,8 @@ class PreprocessingCountsSuite: # noqa: D101
2737
param_names = ("dataset", "layer")
2838

2939
def setup_cache(self) -> None:
30-
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
3140
for dataset, layer in product(*self.params):
32-
adata, batch_key = get_count_dataset(dataset, layer=layer)
33-
assert "lop1p" not in adata.uns
34-
adata.uns["batch_key"] = batch_key
35-
adata.write_h5ad(f"{dataset}_{layer}.h5ad")
41+
cache_adata(dataset, layer)
3642

3743
def setup(self, dataset, layer) -> None:
3844
self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")
@@ -65,6 +71,40 @@ def peakmem_scrublet(self, *_) -> None:
6571
# sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3_paper")
6672

6773

74+
class PreprocessingCountsRngSuite: # noqa: D101
75+
params: tuple[list[Dataset], list[str], list[str]] = (
76+
["pbmc68k_reduced", "pbmc3k"],
77+
["rng", "random_state"],
78+
)
79+
param_names = ("dataset", "layer")
80+
81+
def setup_cache(self) -> None:
82+
for dataset in self.params[0]:
83+
cache_adata(dataset, "counts")
84+
85+
def setup(self, dataset, rng_arg) -> None:
86+
if (
87+
rng_arg == "rng"
88+
and "rng" not in signature(sc.pp.downsample_counts).parameters
89+
):
90+
raise NotImplementedError
91+
self.adata = ad.read_h5ad(f"{dataset}_counts.h5ad")
92+
self.rng_kw: Any = {rng_arg: 0}
93+
self.total = self.adata.X.sum() / 10
94+
95+
def time_downsample_per_cell(self, *_) -> None:
96+
sc.pp.downsample_counts(self.adata, counts_per_cell=3, **self.rng_kw)
97+
98+
def peakmem_downsample_per_cell(self, *_) -> None:
99+
sc.pp.downsample_counts(self.adata, counts_per_cell=3, **self.rng_kw)
100+
101+
def time_downsample_total(self, *_) -> None:
102+
sc.pp.downsample_counts(self.adata, total_counts=self.total, **self.rng_kw)
103+
104+
def peakmem_downsample_total(self, *_) -> None:
105+
sc.pp.downsample_counts(self.adata, total_counts=self.total, **self.rng_kw)
106+
107+
68108
class FastSuite:
69109
"""Suite for fast preprocessing operations."""
70110

@@ -75,11 +115,8 @@ class FastSuite:
75115
param_names = ("dataset", "layer")
76116

77117
def setup_cache(self) -> None:
78-
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
79118
for dataset, layer in product(*self.params):
80-
adata, _ = get_count_dataset(dataset, layer=layer)
81-
assert "lop1p" not in adata.uns
82-
adata.write_h5ad(f"{dataset}_{layer}.h5ad")
119+
cache_adata(dataset, layer)
83120

84121
def setup(self, dataset, layer) -> None:
85122
self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")

docs/release-notes/3983.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for {class}`numpy.random.Generator` to all functions previously accepting a `random_state` parameter {smaller}`P Angerer`

hatch.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ overrides.matrix.deps.python = [
3636
{ if = [ "low-vers" ], value = "3.12" },
3737
]
3838
overrides.matrix.deps.extra-dependencies = [
39+
{ if = [ "stable" ], value = "scipy>=1.17" },
3940
{ if = [ "pre" ], value = "anndata @ git+https://github.com/scverse/anndata.git" },
4041
{ if = [ "pre" ], value = "pandas>=3rc0" },
4142
]

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ filterwarnings = [
278278
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
279279
# everybody uses this zarr 3 feature, including us, XArray, lots of data out there …
280280
"ignore:Consolidated metadata is currently not part:UserWarning",
281+
# joblib fallback to serial mode in restricted multiprocessing environments
282+
"ignore:.*joblib will operate in serial mode:UserWarning",
281283
]
282284

283285
[tool.coverage]

src/scanpy/_docs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Shared docstrings for general parameters."""
2+
3+
from __future__ import annotations
4+
5+
__all__ = ["doc_rng"]
6+
7+
doc_rng = """\
8+
rng
9+
Random number generation to control stochasticity.
10+
11+
If a :type:`SeedLike` value, it’s used to seed a new random number generator;
12+
If a :class:`numpy.random.Generator`, `rng`’s state will be directly advanced;
13+
If :data:`None`, a non-reproducible random number generator is used.
14+
See :func:`numpy.random.default_rng` for more details.
15+
16+
The default value matches legacy scanpy behavior and will change to `None` in scanpy 2.0.\
17+
"""

src/scanpy/_utils/random.py

Lines changed: 124 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,28 @@
77
from typing import TYPE_CHECKING
88

99
import numpy as np
10-
from sklearn.utils import check_random_state
10+
from sklearn.utils.random import check_random_state
1111

1212
from . import ensure_igraph
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import Generator
15+
from collections.abc import Callable, Generator
16+
from typing import Self
1617

18+
from numpy.random import BitGenerator
1719
from numpy.typing import NDArray
1820

1921

2022
__all__ = [
2123
"RNGLike",
2224
"SeedLike",
2325
"_LegacyRandom",
26+
"_LegacyRng",
27+
"_accepts_legacy_random_state",
28+
"_if_legacy_apply_global",
29+
"_legacy_random_state",
30+
"_set_igraph_rng",
2431
"ith_k_tuple",
25-
"legacy_numpy_gen",
2632
"random_k_tuples",
2733
"random_str",
2834
]
@@ -38,34 +44,38 @@
3844

3945

4046
class _RNGIgraph:
41-
"""Random number generator for igraph so global seed is not changed.
47+
"""Random number generator for igraph so global random state is not changed.
4248
4349
See :func:`igraph.set_random_number_generator` for the requirements.
4450
"""
4551

46-
def __init__(self, random_state: int | np.random.RandomState = 0) -> None:
47-
self._rng = check_random_state(random_state)
52+
def __init__(self, rng: SeedLike | RNGLike | None) -> None:
53+
self._rng = np.random.default_rng(rng)
4854

4955
def getrandbits(self, k: int) -> int:
50-
return self._rng.tomaxint() & ((1 << k) - 1)
56+
if isinstance(self._rng, _LegacyRng):
57+
i = self._rng.state.tomaxint()
58+
else:
59+
lims = np.iinfo(np.uint64)
60+
i = int(self._rng.integers(0, lims.max, dtype=np.uint64, endpoint=True))
61+
return i & ((1 << k) - 1)
5162

52-
def randint(self, a: int, b: int) -> int:
53-
return self._rng.randint(a, b + 1)
63+
def randint(self, a: int, b: int) -> np.int64:
64+
"""Can’t use `endpoint` here as _LegacyRng doesn’t support it."""
65+
return self._rng.integers(a, b + 1)
5466

5567
def __getattr__(self, attr: str):
5668
return getattr(self._rng, "normal" if attr == "gauss" else attr)
5769

5870

5971
@contextmanager
60-
def set_igraph_random_state(
61-
random_state: int | np.random.RandomState,
62-
) -> Generator[None, None, None]:
72+
def _set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]:
6373
ensure_igraph()
6474
import igraph
6575

66-
rng = _RNGIgraph(random_state)
76+
ig_rng = _RNGIgraph(rng)
6777
try:
68-
igraph.set_random_number_generator(rng)
78+
igraph.set_random_number_generator(ig_rng)
6979
yield None
7080
finally:
7181
igraph.set_random_number_generator(random)
@@ -76,42 +86,123 @@ def set_igraph_random_state(
7686
###################################
7787

7888

79-
def legacy_numpy_gen(
80-
random_state: _LegacyRandom | None = None,
81-
) -> np.random.Generator:
82-
"""Return a random generator that behaves like the legacy one."""
83-
if random_state is not None:
84-
if isinstance(random_state, np.random.RandomState):
85-
np.random.set_state(random_state.get_state(legacy=False))
86-
return _FakeRandomGen(random_state)
87-
np.random.seed(random_state)
88-
return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator()))
89+
class _LegacyRng(np.random.Generator):
90+
"""A `Generator` that wraps a legacy `RandomState` instance.
8991
92+
To behave like a `RandomState`, it’s not enough to just use a MT19937 `bit_generator`
93+
(as in `Generator(RandomState(seed).bit_generator)`),
94+
so instead this hack uses the exact same random numbers as `RandomState(seed)`.
95+
"""
96+
97+
arg: _LegacyRandom
98+
state: np.random.RandomState
9099

91-
class _FakeRandomGen(np.random.Generator):
92-
_state: np.random.RandomState
100+
def __init__(
101+
self, arg: _LegacyRandom, state: np.random.RandomState | None = None
102+
) -> None:
103+
self.arg = arg
104+
self.state = check_random_state(arg) if state is None else state
93105

94-
def __init__(self, random_state: np.random.RandomState) -> None:
95-
self._state = random_state
106+
@property
107+
def bit_generator(self) -> BitGenerator:
108+
msg = "A _LegacyRng instance has no `bit_generator` attribute."
109+
raise AttributeError(msg)
110+
111+
@classmethod
112+
def wrap_global(
113+
cls,
114+
arg: _LegacyRandom = None,
115+
state: np.random.RandomState | None = None,
116+
) -> Self:
117+
"""Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions."""
118+
if arg is not None:
119+
if isinstance(arg, np.random.RandomState):
120+
np.random.set_state(arg.get_state(legacy=False))
121+
return _LegacyRng(arg, state)
122+
np.random.seed(arg)
123+
return _LegacyRng(arg, np.random.RandomState(np.random.get_bit_generator()))
124+
125+
def spawn(self, n_children: int) -> list[Self]:
126+
"""Return `self` `n_children` times.
127+
128+
In a real generator, the spawned children are independent,
129+
but for backwards compatibility we return the same instance so that its internal state is advanced by each child.
130+
"""
131+
return [self] * n_children
96132

97133
@classmethod
98134
def _delegate(cls) -> None:
135+
names = dict(integers="randint")
99136
for name, meth in np.random.Generator.__dict__.items():
100-
if name.startswith("_") or not callable(meth):
137+
if name.startswith("_") or not callable(meth) or name in cls.__dict__:
101138
continue
102139

103140
def mk_wrapper(name: str, meth):
104141
# Old pytest versions try to run the doctests
105142
@wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"})
106-
def wrapper(self: _FakeRandomGen, *args, **kwargs):
107-
return getattr(self._state, name)(*args, **kwargs)
143+
def wrapper(self: _LegacyRng, *args, **kwargs):
144+
return getattr(self.state, name)(*args, **kwargs)
108145

109146
return wrapper
110147

111-
setattr(cls, name, mk_wrapper(name, meth))
148+
setattr(cls, names.get(name, name), mk_wrapper(name, meth))
149+
150+
151+
_LegacyRng._delegate()
152+
153+
154+
def _if_legacy_apply_global(rng: np.random.Generator, /) -> np.random.Generator:
155+
"""Wrap the global legacy RNG if `rng` is a `_LegacyRng`.
156+
157+
This is used where our code used to call `np.random.seed()`.
158+
It’s a no-op if `rng` is not a `_LegacyRng`.
159+
"""
160+
if not isinstance(rng, _LegacyRng):
161+
return rng
162+
163+
return _LegacyRng.wrap_global(rng.arg, rng.state)
164+
112165

166+
def _legacy_random_state(
167+
rng: SeedLike | RNGLike | None, /, *, always_state: bool = False
168+
) -> _LegacyRandom:
169+
"""Convert a np.random.Generator into a legacy `random_state` argument.
170+
171+
If `rng` is already a `_LegacyRng`, return its original `arg` attribute.
172+
"""
173+
if isinstance(rng, _LegacyRng):
174+
return rng.state if always_state else rng.arg
175+
[bitgen] = np.random.default_rng(rng).bit_generator.spawn(1)
176+
return np.random.RandomState(bitgen)
177+
178+
179+
def _accepts_legacy_random_state[**P, R](
180+
random_state_default: _LegacyRandom, /
181+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
182+
"""Make a function accept `random_state: _LegacyRandom` and pass it as `rng`.
183+
184+
If the decorated function is called with a `random_state` argument,
185+
it’ll be wrapped in a `_LegacyRng`.
186+
Passing both `rng` and `random_state` at the same time is an error.
187+
If neither is given, `random_state_default` is used.
188+
"""
113189

114-
_FakeRandomGen._delegate()
190+
def decorator(func: Callable[P, R]) -> Callable[P, R]:
191+
@wraps(func)
192+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
193+
match "random_state" in kwargs, "rng" in kwargs:
194+
case True, True:
195+
msg = "Specify at most one of `rng` and `random_state`."
196+
raise TypeError(msg)
197+
case True, False:
198+
kwargs["rng"] = _LegacyRng(kwargs.pop("random_state"))
199+
case False, False:
200+
kwargs["rng"] = _LegacyRng(random_state_default)
201+
return func(*args, **kwargs)
202+
203+
return wrapper
204+
205+
return decorator
115206

116207

117208
###################

0 commit comments

Comments
 (0)