Skip to content

Commit 91193e4

Browse files
authored
Implement a basic working ModelTuner API with Ray Tune (#1785)
* Add hyperopt and ray as optional dependencies * Add basic module structure * Add classproperty and dependencies decorators * Add file with default search spaces and helpers * Add skeleton ModelTuner API * Add Tunable type and notebook util function * Add TunerManager API, update __init__s * Add skeleton functions for TunerManager * Implement TunerManager validation functions * Address scvi, ray import kernel crashes * Update CI workflow with autotune deps * Implement dummy SCVI autotune interface * Add sanity autotune test * Retrigger checks * Update .github/workflows/test.yml * Potential fix for CUDA forked subprocess error * Potential fix for CUDA forked process error * Force spawn subprocesses * Add optional pytest mark, change to forkserver default * Try import ray on init * Update docs with autotune * Update docs and docstrings for autotune * Update docs, include more basic tests for autotune * Faster tests, validate anndata earlier * Fix missing kwargs in test
1 parent b2784f7 commit 91193e4

18 files changed

Lines changed: 909 additions & 7 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
id: dependencies
3838
run: |
3939
pip install pytest-cov
40-
pip install .[dev,pymde]
40+
pip install .[dev,pymde,autotune]
4141
4242
# Following checks are independent and are run even if one fails
4343
- name: Lint with flake8

docs/api/developer.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,24 @@ TrainingPlans define train/test/val optimization steps for modules.
237237
238238
```
239239

240+
## Model hyperparameter autotuning
241+
242+
`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune]. These
243+
classes allow for new model classes to be easily integrated with the module.
244+
245+
```{eval-rst}
246+
.. currentmodule:: scvi
247+
```
248+
249+
```{eval-rst}
250+
.. autosummary::
251+
:toctree: reference/
252+
:nosignatures:
253+
254+
autotune.TunerManager
255+
autotune.Tunable
256+
```
257+
240258
## Utilities
241259

242260
```{eval-rst}
@@ -254,3 +272,5 @@ Utility functions used by scvi-tools.
254272
utils.setup_anndata_dsp
255273
utils.attrdict
256274
```
275+
276+
[ray tune]: https://docs.ray.io/en/latest/tune/index.html

docs/api/user.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,22 @@ Here we maintain a few package specific utilities for feature selection, etc.
9999
data.organize_multiome_anndatas
100100
```
101101

102+
```{eval-rst}
103+
.. currentmodule:: scvi
104+
```
105+
106+
## Model hyperparameter autotuning
107+
108+
`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune].
109+
110+
```{eval-rst}
111+
.. autosummary::
112+
:toctree: reference/
113+
:nosignatures:
114+
115+
autotune.ModelTuner
116+
```
117+
102118
## Utilities
103119

104120
Here we maintain miscellaneous general methods.
@@ -126,3 +142,4 @@ An instance of the {class}`~scvi._settings.ScviConfig` is available as `scvi.set
126142
[anndata]: https://anndata.readthedocs.io/en/stable/
127143
[scanpy]: https://scanpy.readthedocs.io/en/stable/index.html
128144
[utilities]: https://scanpy.readthedocs.io/en/stable/api/index.html#reading
145+
[ray tune]: https://docs.ray.io/en/latest/tune/index.html

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
"jax": ("https://jax.readthedocs.io/en/latest/", None),
9393
"ml_collections": ("https://ml-collections.readthedocs.io/en/latest/", None),
9494
"mudata": ("https://mudata.readthedocs.io/en/latest/", None),
95+
"ray": ("https://docs.ray.io/en/latest/", None),
9596
}
9697

9798

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ flake8 = {version = ">=3.7.7", optional = true}
4040
flax = "*"
4141
furo = {version = ">=2022.2.14.1", optional = true}
4242
h5py = ">=2.9.0"
43+
hyperopt = {version = ">=0.2", optional = true}
4344
importlib-metadata = {version = ">1.0", python = "<3.8"}
4445
ipython = {version = ">=7.20", optional = true, python = ">=3.7"}
4546
ipywidgets = "*"
@@ -68,6 +69,7 @@ pytest = {version = ">=4.4", optional = true}
6869
python = ">=3.7,<4.0"
6970
python-igraph = {version = "*", optional = true}
7071
pytorch-lightning = ">=1.8.0,<1.9"
72+
ray = {extras = ["tune"], version = ">=2.1.0", optional = true}
7173
rich = ">=9.1.0"
7274
scanpy = {version = ">=1.6", optional = true}
7375
scikit-learn = ">=0.21.2"
@@ -100,6 +102,7 @@ docs = [
100102
"sphinxcontrib-bibtex",
101103
"myst-parser",
102104
]
105+
autotune = ["hyperopt", "ray", "ipython"]
103106
pymde = ["pymde"]
104107
tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc", "pynndescent", "pymde"]
105108

scvi/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
# Set default logging handler to avoid logging with logging.lastResort logger.
44
import logging
55

6+
try:
7+
# necessary as importing scvi after ray causes kernel crash
8+
from ray import tune # noqa
9+
except ImportError:
10+
pass
11+
612
from ._constants import REGISTRY_KEYS
713
from ._settings import settings
814

915
# this import needs to come after prior imports to prevent circular import
10-
from . import data, model, external, utils
16+
from . import autotune, data, model, external, utils
1117

1218
# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
1319
# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
@@ -25,4 +31,12 @@
2531
scvi_logger = logging.getLogger("scvi")
2632
scvi_logger.propagate = False
2733

28-
__all__ = ["settings", "REGISTRY_KEYS", "data", "model", "external", "utils"]
34+
__all__ = [
35+
"settings",
36+
"REGISTRY_KEYS",
37+
"autotune",
38+
"data",
39+
"model",
40+
"external",
41+
"utils",
42+
]

scvi/_decorators.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from functools import wraps
2+
from typing import Callable, List, Union
3+
4+
5+
class classproperty:
6+
"""
7+
Read-only class property decorator.
8+
9+
Source: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
10+
"""
11+
12+
def __init__(self, f):
13+
self.f = f
14+
15+
def __get__(self, obj, owner):
16+
return self.f(owner)
17+
18+
19+
def dependencies(packages: Union[str, List[str]]) -> Callable:
20+
"""
21+
Decorator to check for dependencies.
22+
23+
Parameters
24+
----------
25+
packages
26+
A string or list of strings of packages to check for.
27+
"""
28+
if isinstance(packages, str):
29+
packages = [packages]
30+
31+
def decorator(fn: Callable) -> Callable:
32+
@wraps(fn)
33+
def wrapper(*args, **kwargs):
34+
try:
35+
import importlib
36+
37+
for package in packages:
38+
importlib.import_module(package)
39+
except ImportError:
40+
raise ImportError(
41+
f"Please install {packages} to use this functionality."
42+
)
43+
return fn(*args, **kwargs)
44+
45+
return wrapper
46+
47+
return decorator

scvi/autotune/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ._manager import TunerManager
2+
from ._tuner import ModelTuner
3+
from ._types import Tunable
4+
5+
__all__ = ["ModelTuner", "Tunable", "TunerManager"]

scvi/autotune/_defaults.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
2+
3+
from scvi import model
4+
from scvi.module.base import BaseModuleClass, JaxBaseModuleClass, PyroBaseModuleClass
5+
from scvi.train import TrainRunner
6+
7+
# colors for rich table columns
8+
COLORS = [
9+
"dodger_blue1",
10+
"dark_violet",
11+
"green",
12+
"dark_orange",
13+
]
14+
15+
# default rich table column kwargs
16+
COLUMN_KWARGS = {
17+
"justify": "center",
18+
"no_wrap": True,
19+
"overflow": "fold",
20+
}
21+
22+
# maps classes to the type of hyperparameters they use
23+
TUNABLE_TYPES = {
24+
"model": [
25+
BaseModuleClass,
26+
JaxBaseModuleClass,
27+
PyroBaseModuleClass,
28+
],
29+
"train": [
30+
LightningDataModule,
31+
Trainer,
32+
TrainRunner,
33+
],
34+
"train_plan": [
35+
LightningModule,
36+
],
37+
}
38+
39+
# supported model classes
40+
SUPPORTED = [model.SCVI]
41+
42+
# default hyperparameter search spaces for each model class
43+
DEFAULTS = {
44+
model.SCVI: {
45+
"n_hidden": {"fn": "choice", "args": [[64, 128]]},
46+
}
47+
}

0 commit comments

Comments
 (0)