
Commit 5470b07

Author: Donglai Wei

Use independent shards for multi-GPU test

1 parent f6a1a66 · commit 5470b07

2 files changed: 182 additions & 1 deletion

scripts/main.py

Lines changed: 110 additions & 1 deletion
@@ -437,6 +437,99 @@ def maybe_limit_test_devices(cfg: Config, datamodule) -> bool:
     return True


+def resolve_test_rank_shard_from_env() -> tuple[int | None, int | None]:
+    """Return rank/world_size for externally launched multi-process test jobs."""
+    for rank_key, world_key in (("RANK", "WORLD_SIZE"), ("SLURM_PROCID", "SLURM_NTASKS")):
+        rank_raw = os.environ.get(rank_key)
+        world_raw = os.environ.get(world_key)
+        if rank_raw is None or world_raw is None:
+            continue
+        try:
+            rank = int(rank_raw)
+            world_size = int(world_raw)
+        except ValueError:
+            continue
+        if world_size > 1:
+            return rank, world_size
+
+    return None, None
+
+
+def resolve_test_image_paths(cfg: Config) -> list[str]:
+    """Resolve test image paths from config for shard planning."""
+    data_cfg = getattr(cfg, "data", None)
+    test_image = getattr(getattr(data_cfg, "test", None), "image", None)
+    if not test_image:
+        return []
+
+    from connectomics.training.lightning.path_utils import expand_file_paths
+
+    try:
+        return expand_file_paths(test_image)
+    except Exception as exc:
+        print(f" WARNING: Failed to resolve test_image paths for sharding: {exc}")
+        return []
+
+
+def maybe_enable_independent_test_sharding(args, cfg: Config) -> bool:
+    """Run test as independent single-GPU shards instead of DDP when rank info is available."""
+    requested_devices = int(getattr(cfg.system, "num_gpus", 0) or 0)
+    if requested_devices <= 1:
+        return False
+
+    shard_id = getattr(args, "shard_id", None)
+    num_shards = getattr(args, "num_shards", None)
+    source = None
+
+    if shard_id is not None and num_shards is not None and int(num_shards) > 1:
+        source = "explicit shard arguments"
+    else:
+        test_image_paths = resolve_test_image_paths(cfg)
+        if len(test_image_paths) <= 1:
+            return False
+
+        shard_id, num_shards = resolve_test_rank_shard_from_env()
+        if shard_id is None or num_shards is None:
+            return False
+
+        args.shard_id = shard_id
+        args.num_shards = num_shards
+        source = "distributed launcher environment"
+
+    tta_cfg = getattr(getattr(cfg, "inference", None), "test_time_augmentation", None)
+    if tta_cfg is not None and bool(getattr(tta_cfg, "distributed_sharding", False)):
+        print(
+            " WARNING: Disabling distributed TTA sharding for independent per-rank test sharding."
+        )
+        tta_cfg.distributed_sharding = False
+
+    cfg.system.num_gpus = 1 if torch.cuda.is_available() else 0
+    print(
+        " INFO: Independent multi-GPU test sharding enabled "
+        f"({source}); each process will handle its own shard with no DDP communication."
+    )
+    return True
+
+
+def has_assigned_test_shard(cfg: Config, args) -> bool:
+    """Return True if the current shard has at least one test volume to process."""
+    shard_id = getattr(args, "shard_id", None)
+    num_shards = getattr(args, "num_shards", None)
+    if shard_id is None or num_shards is None:
+        return True
+
+    test_image_paths = resolve_test_image_paths(cfg)
+    if not test_image_paths:
+        return True
+
+    if test_image_paths[shard_id::num_shards]:
+        return True
+
+    print(f" Shard {shard_id}/{num_shards} is empty, nothing to do.")
+    print("[OK] Test completed successfully (empty shard).")
+    return False
+
+
 def shard_test_datamodule(datamodule, shard_id: int, num_shards: int):
     """Shard test volumes across machines.
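A quick usage sketch of the environment resolver above (illustrative only, not part of the commit; it assumes scripts.main is importable, as in the unit tests below, and that no SLURM variables are set):

import os

from scripts.main import resolve_test_rank_shard_from_env

# torchrun-style launcher environment: the RANK/WORLD_SIZE pair is checked first.
os.environ["RANK"], os.environ["WORLD_SIZE"] = "1", "4"
assert resolve_test_rank_shard_from_env() == (1, 4)

# A world size of 1 is ignored, so plain single-process runs are unaffected
# (assuming SLURM_PROCID/SLURM_NTASKS are also unset).
os.environ["WORLD_SIZE"] = "1"
assert resolve_test_rank_shard_from_env() == (None, None)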
@@ -766,6 +859,11 @@ def main():
     print(f"Random seed set to: {cfg.system.seed}")
     seed_everything(cfg.system.seed, workers=True)

+    if args.mode == "test":
+        maybe_enable_independent_test_sharding(args, cfg)
+        if not has_assigned_test_shard(cfg, args):
+            return
+
     # Cache-only preflight path for test mode (can skip model/trainer/dataloader entirely).
     if try_cache_only_test_execution(cfg, args.mode, args.shard_id, args.num_shards):
         return
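The early return above is what makes over-provisioned launches safe: has_assigned_test_shard slices the volume list round-robin, and any rank whose slice is empty exits immediately. Illustrative arithmetic only (the volume names are made up):

# Five volumes across four shards: shard 0 owns two, the rest one each.
# With fewer volumes than shards, trailing shards get an empty slice.
volumes = ["vol0", "vol1", "vol2", "vol3", "vol4"]
for shard_id in range(4):
    print(shard_id, volumes[shard_id::4])
# 0 ['vol0', 'vol4']
# 1 ['vol1']
# 2 ['vol2']
# 3 ['vol3']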
@@ -839,6 +937,17 @@
     # Re-resolve test-stage runtime overrides after tuning, including sentinels.
     cfg = resolve_test_stage_runtime(cfg)

+    if maybe_enable_independent_test_sharding(args, cfg):
+        trainer = create_trainer(
+            cfg,
+            run_dir=run_dir,
+            fast_dev_run=args.fast_dev_run,
+            ckpt_path=ckpt_path,
+            mode="test",
+        )
+    if not has_assigned_test_shard(cfg, args):
+        return
+
     # Create datamodule
     datamodule = create_datamodule(cfg, mode="test")
@@ -882,7 +991,7 @@

     trainer.test(
         model,
-        datamodule=datamodule,
+        datamodule,
         ckpt_path=test_ckpt_path,
     )
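One way to drive the explicit-shard branch end to end is a per-GPU launcher like the sketch below. The flag names --mode, --shard-id, and --num-shards are assumptions (this diff never shows the argument parser), so check scripts/main.py for the real spellings; under srun or torchrun the shard flags can be omitted entirely, because each rank inherits SLURM_PROCID/SLURM_NTASKS or RANK/WORLD_SIZE from the launcher environment instead.

# Hypothetical launcher: one independent single-GPU test process per shard,
# with no DDP communication between them.
import os
import subprocess
import sys

NUM_SHARDS = 4
procs = []
for shard_id in range(NUM_SHARDS):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(shard_id))  # pin one GPU per shard
    procs.append(subprocess.Popen(
        [sys.executable, "scripts/main.py", "--mode", "test",
         "--shard-id", str(shard_id), "--num-shards", str(NUM_SHARDS)],
        env=env,
    ))
assert all(p.wait() == 0 for p in procs)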

tests/unit/test_main_runtime_stage_switch.py

Lines changed: 72 additions & 0 deletions
@@ -1,12 +1,16 @@
 import argparse
 from pathlib import Path

+import torch
+
 from connectomics.config import Config, save_config
 from connectomics.config.schema.inference import EvaluationConfig
 from connectomics.training.lightning.utils import setup_config
 from scripts.main import (
     _is_test_evaluation_enabled,
+    has_assigned_test_shard,
     maybe_limit_test_devices,
+    maybe_enable_independent_test_sharding,
     resolve_test_stage_runtime,
 )
@@ -30,6 +34,8 @@ def _make_args(config_path: Path, mode: str = "test"):
         tune_trials=None,
         nnunet_preprocess=False,
         overrides=[],
+        shard_id=None,
+        num_shards=None,
     )

@@ -106,3 +112,69 @@ def test_maybe_limit_test_devices_keeps_distributed_tta_sharding_for_single_volu
     assert changed is False
     assert cfg.system.num_gpus == 4
     assert cfg.inference.test_time_augmentation.distributed_sharding is True
+
+
+def test_maybe_enable_independent_test_sharding_uses_rank_env_for_multi_volume_tests(
+    tmp_path, monkeypatch
+):
+    cfg = Config()
+    cfg.system.num_gpus = 4
+    cfg.inference.test_time_augmentation.distributed_sharding = True
+    args = _make_args(tmp_path / "config.yaml")
+
+    monkeypatch.setenv("SLURM_PROCID", "2")
+    monkeypatch.setenv("SLURM_NTASKS", "4")
+    monkeypatch.setattr("scripts.main.resolve_test_image_paths", lambda _cfg: ["a", "b", "c", "d"])
+
+    changed = maybe_enable_independent_test_sharding(args, cfg)
+
+    assert changed is True
+    assert args.shard_id == 2
+    assert args.num_shards == 4
+    assert cfg.system.num_gpus == (1 if torch.cuda.is_available() else 0)
+    assert cfg.inference.test_time_augmentation.distributed_sharding is False
+
+
+def test_maybe_enable_independent_test_sharding_uses_explicit_shard_args(tmp_path):
+    cfg = Config()
+    cfg.system.num_gpus = 4
+    args = _make_args(tmp_path / "config.yaml")
+    args.shard_id = 1
+    args.num_shards = 4
+
+    changed = maybe_enable_independent_test_sharding(args, cfg)
+
+    assert changed is True
+    assert cfg.system.num_gpus == (1 if torch.cuda.is_available() else 0)
+
+
+def test_maybe_enable_independent_test_sharding_skips_single_volume_tests(
+    tmp_path, monkeypatch
+):
+    cfg = Config()
+    cfg.system.num_gpus = 4
+    args = _make_args(tmp_path / "config.yaml")
+
+    monkeypatch.setenv("SLURM_PROCID", "0")
+    monkeypatch.setenv("SLURM_NTASKS", "4")
+    monkeypatch.setattr("scripts.main.resolve_test_image_paths", lambda _cfg: ["only_one"])
+
+    changed = maybe_enable_independent_test_sharding(args, cfg)
+
+    assert changed is False
+    assert args.shard_id is None
+    assert args.num_shards is None
+    assert cfg.system.num_gpus == 4
+
+
+def test_has_assigned_test_shard_returns_false_for_empty_slice(
+    tmp_path, monkeypatch
+):
+    args = _make_args(tmp_path / "config.yaml")
+    cfg = Config()
+    args.shard_id = 3
+    args.num_shards = 4
+
+    monkeypatch.setattr("scripts.main.resolve_test_image_paths", lambda _cfg: ["vol0"])
+
+    assert has_assigned_test_shard(cfg, args) is False
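A detail the first test leans on (sketch, same import assumption as above): RANK/WORLD_SIZE take precedence over the SLURM pair, because they come first in the tuple that resolve_test_rank_shard_from_env scans.

import os

from scripts.main import resolve_test_rank_shard_from_env

# Both pairs set: the RANK/WORLD_SIZE values win.
os.environ.update({"RANK": "0", "WORLD_SIZE": "2",
                   "SLURM_PROCID": "5", "SLURM_NTASKS": "8"})
assert resolve_test_rank_shard_from_env() == (0, 2)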
