
Commit af62e50

Author: Donglai Wei (committed)
Fix tune trial timeout handling
1 parent ee82c46 commit af62e50

6 files changed

Lines changed: 153 additions & 10 deletions


connectomics/decoding/tuning/optuna_tuner.py

Lines changed: 15 additions & 2 deletions
@@ -27,6 +27,7 @@

 import h5py
 import numpy as np
+import torch
 from omegaconf import DictConfig, OmegaConf

 try:

@@ -335,8 +336,14 @@ def _trial_evaluation_worker(send_conn, evaluation_kind: str, payload: Dict[str,


 def _get_trial_process_context() -> mp.context.BaseContext:
-    """Prefer fork to avoid copying large prediction arrays when available."""
-    for method in ("fork", "spawn"):
+    """Choose a multiprocessing start method for timeout-enforced trials."""
+    methods = ("fork", "spawn")
+    if torch.cuda.is_available() and torch.cuda.is_initialized():
+        # Tune mode runs inference before Optuna; once CUDA is initialized,
+        # forking the parent process is unsafe and can hang.
+        methods = ("spawn", "fork")
+
+    for method in methods:
         try:
             return mp.get_context(method)
         except ValueError:
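The fork/spawn ordering matters because the trial evaluation runs in a child process that can be killed when a per-trial timeout expires. The following is a minimal, self-contained sketch of that pattern; it is not the repository's actual worker code, and _slow_eval / run_with_timeout are illustrative names only.

import multiprocessing as mp
import time


def _slow_eval(seconds: float) -> None:
    # Stand-in for a long WaterZ/decoding evaluation.
    time.sleep(seconds)


def run_with_timeout(budget_s: float) -> bool:
    # In the repo the context would come from _get_trial_process_context();
    # "spawn" is hard-coded here to keep the sketch standalone.
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_slow_eval, args=(10.0,))
    proc.start()
    proc.join(timeout=budget_s)
    if proc.is_alive():
        # Budget exceeded: terminate the child instead of blocking the study.
        proc.terminate()
        proc.join()
        return False
    return True


if __name__ == "__main__":
    print(run_with_timeout(1.0))  # False: the fake evaluation overran its budget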
@@ -795,6 +802,12 @@ def optimize(self) -> optuna.Study:

         metric = self.tune_cfg.optimization["single_objective"]["metric"]
         trial_timeout = self._get_trial_timeout_seconds()
+        if timeout is not None and trial_timeout is None:
+            logger.warning(
+                "tune.timeout=%s limits the whole Optuna study, not one trial. "
+                "Long WaterZ runs will still block unless tune.trial_timeout is set.",
+                timeout,
+            )
         logger.info(
             "Starting Optuna optimization: %s | Trials: %s | Metric: %s | "
             "Direction: %s | Trial timeout: %s",

connectomics/training/lightning/utils.py

Lines changed: 34 additions & 0 deletions
@@ -157,6 +157,18 @@ def parse_args():
         default=None,
         help="Number of Optuna trials (overrides config, use with --mode tune or tune-test)",
     )
+    tune_group.add_argument(
+        "--tune-timeout",
+        type=int,
+        default=None,
+        help="Whole-study Optuna timeout in seconds (overrides tune.timeout)",
+    )
+    tune_group.add_argument(
+        "--tune-trial-timeout",
+        type=int,
+        default=None,
+        help="Per-trial tuning timeout in seconds (overrides tune.trial_timeout)",
+    )
     parser.add_argument(
         "overrides",
         nargs="*",

@@ -217,6 +229,28 @@ def setup_config(args) -> Config:
     # Resolve data paths on merged runtime data section.
     cfg = resolve_data_paths(cfg)

+    if cfg.tune is not None:
+        if args.tune_trials is not None:
+            logger.info("Overriding tune.n_trials: %s -> %s", cfg.tune.n_trials, args.tune_trials)
+            cfg.tune.n_trials = args.tune_trials
+
+        if args.tune_timeout is not None:
+            logger.info("Overriding tune.timeout: %s -> %s", cfg.tune.timeout, args.tune_timeout)
+            cfg.tune.timeout = args.tune_timeout
+
+        if args.tune_trial_timeout is not None:
+            logger.info(
+                "Overriding tune.trial_timeout: %s -> %s",
+                cfg.tune.trial_timeout,
+                args.tune_trial_timeout,
+            )
+            cfg.tune.trial_timeout = args.tune_trial_timeout
+    elif any(
+        value is not None
+        for value in (args.tune_trials, args.tune_timeout, args.tune_trial_timeout)
+    ):
+        logger.warning("Ignoring --tune-* CLI overrides because the config has no tune section")
+
     # Override max_epochs if --reset-max-epochs is specified
     if args.reset_max_epochs is not None:
         logger.info(
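A minimal standalone check of the two new flags. The argument names and help strings are taken from the diff above; the real parser lives in parse_args() and carries many more options.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tune-timeout", type=int, default=None,
                    help="Whole-study Optuna timeout in seconds (overrides tune.timeout)")
parser.add_argument("--tune-trial-timeout", type=int, default=None,
                    help="Per-trial tuning timeout in seconds (overrides tune.trial_timeout)")

args = parser.parse_args(["--tune-timeout", "3600", "--tune-trial-timeout", "300"])
assert args.tune_timeout == 3600
assert args.tune_trial_timeout == 300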

tests/unit/test_lit_utils.py

Lines changed: 28 additions & 1 deletion
@@ -5,6 +5,7 @@
 import pytest

 from connectomics.config import Config, save_config
+from connectomics.config.schema.stages import TuneConfig
 from connectomics.training.lightning.data_factory import _calculate_validation_steps_per_epoch
 from connectomics.training.lightning.path_utils import (
     expand_file_paths as canonical_expand_file_paths,

@@ -27,6 +28,9 @@ def _make_args(
     fast_dev_run: int = 0,
     mode: str = "train",
     nnunet_preprocess: bool = False,
+    tune_timeout: int | None = None,
+    tune_trial_timeout: int | None = None,
+    tune_trials: int | None = None,
 ):
     return argparse.Namespace(
         config=str(config_path),

@@ -42,7 +46,9 @@
         external_prefix=None,
         params=None,
         param_source=None,
-        tune_trials=None,
+        tune_trials=tune_trials,
+        tune_timeout=tune_timeout,
+        tune_trial_timeout=tune_trial_timeout,
         nnunet_preprocess=nnunet_preprocess,
         overrides=overrides or [],
     )

@@ -84,6 +90,27 @@ def test_setup_config_enables_nnunet_preprocess_from_cli_switch(tmp_path):
     assert updated.data.nnunet_preprocessing.enabled is True


+def test_setup_config_applies_tune_timeout_cli_overrides(tmp_path):
+    cfg = Config()
+    cfg.tune = TuneConfig()
+
+    cfg_path = tmp_path / "config.yaml"
+    save_config(cfg, cfg_path)
+
+    args = _make_args(
+        cfg_path,
+        mode="tune",
+        tune_trials=17,
+        tune_timeout=3600,
+        tune_trial_timeout=300,
+    )
+    updated = setup_config(args)
+
+    assert updated.tune.n_trials == 17
+    assert updated.tune.timeout == 3600
+    assert updated.tune.trial_timeout == 300
+
+
 def test_expand_file_paths_handles_globs_and_lists(tmp_path):
     data_dir = tmp_path / "data"
     data_dir.mkdir()

tests/unit/test_main_cli_contract.py

Lines changed: 10 additions & 0 deletions
@@ -59,3 +59,13 @@ def test_parse_args_demo_mode_requires_no_config(monkeypatch):
     args = _parse_with_argv(monkeypatch, ["--demo"])
     assert args.demo is True
     assert args.config is None
+
+
+def test_parse_args_accepts_tune_timeout_flags(monkeypatch):
+    args = _parse_with_argv(
+        monkeypatch,
+        ["--tune-timeout", "3600", "--tune-trial-timeout", "300"],
+    )
+
+    assert args.tune_timeout == 3600
+    assert args.tune_trial_timeout == 300

tests/unit/test_main_runtime_stage_switch.py

Lines changed: 5 additions & 7 deletions
@@ -9,8 +9,8 @@
 from scripts.main import (
     _is_test_evaluation_enabled,
     has_assigned_test_shard,
-    maybe_limit_test_devices,
     maybe_enable_independent_test_sharding,
+    maybe_limit_test_devices,
     resolve_test_stage_runtime,
 )


@@ -32,6 +32,8 @@ def _make_args(config_path: Path, mode: str = "test"):
         params=None,
         param_source=None,
         tune_trials=None,
+        tune_timeout=None,
+        tune_trial_timeout=None,
         nnunet_preprocess=False,
         overrides=[],
         shard_id=None,

@@ -163,9 +165,7 @@ def test_maybe_enable_independent_test_sharding_uses_explicit_shard_args(tmp_pat
     assert cfg.system.num_gpus == (1 if torch.cuda.is_available() else 0)


-def test_maybe_enable_independent_test_sharding_skips_single_volume_tests(
-    tmp_path, monkeypatch
-):
+def test_maybe_enable_independent_test_sharding_skips_single_volume_tests(tmp_path, monkeypatch):
     cfg = Config()
     cfg.system.num_gpus = 4
     args = _make_args(tmp_path / "config.yaml")

@@ -182,9 +182,7 @@ def test_maybe_enable_independent_test_sharding_skips_single_volume_tests(
     assert cfg.system.num_gpus == 4


-def test_has_assigned_test_shard_returns_false_for_empty_slice(
-    tmp_path, monkeypatch
-):
+def test_has_assigned_test_shard_returns_false_for_empty_slice(tmp_path, monkeypatch):
     args = _make_args(tmp_path / "config.yaml")
     cfg = Config()
     args.shard_id = 3

tests/unit/test_optuna_tuner.py

Lines changed: 61 additions & 0 deletions
@@ -10,6 +10,7 @@
 from connectomics.decoding.tuning.optuna_tuner import (
     OptunaDecodingTuner,
     TrialEvaluationTimeoutError,
+    _get_trial_process_context,
     load_and_apply_best_params,
     run_tuning,
 )

@@ -361,3 +362,63 @@ def _raise_timeout(_evaluation_kind, _payload):
     assert trial.user_attrs["timed_out"] is True
     assert trial.user_attrs["timeout_stage"] == "waterz_batch"
     assert trial.user_attrs["trial_timeout"] == 30.0
+
+
+def test_get_trial_process_context_prefers_spawn_after_cuda_init(monkeypatch):
+    observed = []
+
+    class _DummyContext:
+        pass
+
+    def _fake_get_context(method=None):
+        observed.append(method)
+        if method == "spawn":
+            return _DummyContext()
+        raise ValueError(f"unsupported: {method}")
+
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.torch.cuda.is_available",
+        lambda: True,
+    )
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.torch.cuda.is_initialized",
+        lambda: True,
+    )
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.mp.get_context", _fake_get_context
+    )
+
+    ctx = _get_trial_process_context()
+
+    assert isinstance(ctx, _DummyContext)
+    assert observed == ["spawn"]
+
+
+def test_get_trial_process_context_prefers_fork_without_cuda_init(monkeypatch):
+    observed = []
+
+    class _DummyContext:
+        pass
+
+    def _fake_get_context(method=None):
+        observed.append(method)
+        if method == "fork":
+            return _DummyContext()
+        raise ValueError(f"unsupported: {method}")
+
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.torch.cuda.is_available",
+        lambda: False,
+    )
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.torch.cuda.is_initialized",
+        lambda: False,
+    )
+    monkeypatch.setattr(
+        "connectomics.decoding.tuning.optuna_tuner.mp.get_context", _fake_get_context
+    )
+
+    ctx = _get_trial_process_context()
+
+    assert isinstance(ctx, _DummyContext)
+    assert observed == ["fork"]
