Skip to content

Commit 6d70585

Browse files
author
Donglai Wei
committed
Fix multi-GPU test sharding
1 parent c086af2 commit 6d70585

9 files changed

Lines changed: 232 additions & 5 deletions

File tree

connectomics/inference/tta.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,10 @@ def _apply_distributed_reduction(
604604
) -> Optional[torch.Tensor]:
605605
"""Reduce ensemble results across DDP ranks. Returns None on non-zero ranks."""
606606
_is_dist, rank, _world_size = self._distributed_context()
607+
self._validate_distributed_reduction_shape(
608+
ensemble_result,
609+
reduction_device=reduction_device,
610+
)
607611

608612
if ensemble_mode == "mean":
609613
reduced_sum = self._reduce_cpu_tensor_to_rank_zero(
@@ -640,6 +644,53 @@ def _apply_distributed_reduction(
640644

641645
return None # non-zero ranks
642646

647+
def _validate_distributed_reduction_shape(
    self,
    ensemble_result: torch.Tensor,
    *,
    reduction_device: torch.device,
) -> None:
    """Fail fast when DDP ranks try to reduce different TTA prediction shapes."""
    is_dist, _rank, world_size = self._distributed_context()
    if not (is_dist and world_size > 1):
        # A single process cannot disagree with itself; nothing to validate.
        return

    limit = 6  # fixed-width encoding supports at most 6 dimensions
    local_ndim = int(ensemble_result.ndim)
    if local_ndim > limit:
        raise RuntimeError(
            "Distributed TTA shape validation only supports tensors with up to "
            f"{limit} dimensions, got shape {tuple(ensemble_result.shape)}."
        )

    # Encode [ndim, d0, ..., d{ndim-1}, 0-padding] into one int64 vector so all
    # ranks can all_gather it regardless of their local dimensionality.
    encoded = torch.zeros(limit + 1, device=reduction_device, dtype=torch.int64)
    encoded[0] = local_ndim
    for axis, size in enumerate(ensemble_result.shape):
        encoded[axis + 1] = int(size)

    per_rank = [torch.empty_like(encoded) for _ in range(world_size)]
    torch.distributed.all_gather(per_rank, encoded)

    # Decode each rank's vector back into a shape tuple of plain ints.
    shapes: list[tuple[int, ...]] = [
        tuple(int(size) for size in vec[1 : 1 + int(vec[0].item())].tolist())
        for vec in per_rank
    ]

    reference = shapes[0]
    if any(shape != reference for shape in shapes[1:]):
        shape_summary = ", ".join(
            f"rank {rank_idx}: {shape}" for rank_idx, shape in enumerate(shapes)
        )
        raise RuntimeError(
            "Distributed TTA sharding requires every DDP rank to reduce predictions "
            f"with the same shape, got {shape_summary}. This usually means multiple "
            "test volumes were sharded across ranks; disable "
            "`inference.test_time_augmentation.distributed_sharding` for multi-volume "
            "tests."
        )
643694
def _apply_mask_to_result(
644695
self,
645696
ensemble_result: torch.Tensor,

connectomics/training/lightning/data.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from monai.data import CacheDataset, Dataset
1616
from monai.transforms import Compose
17-
from torch.utils.data import DataLoader
17+
from torch.utils.data import DataLoader, Sampler
1818

1919

2020
class ConnectomicsDataModule(pl.LightningDataModule):
@@ -125,21 +125,26 @@ def val_dataloader(self):
125125
return self._create_dataloader(self.val_dataset, shuffle=False)
126126

127127
def test_dataloader(self):
    """Build the test DataLoader, sharding samples across DDP ranks when active."""
    dataset = self.test_dataset
    # Only attach a sharding sampler when a distributed process group is live;
    # otherwise the plain sequential loader is used.
    shard = dataset is not None and _is_distributed_evaluation_active()
    eval_sampler = DistributedEvaluationSampler(dataset) if shard else None
    return self._create_dataloader(
        dataset,
        shuffle=False,
        collate_fn=collate_dict_list,
        sampler=eval_sampler,
    )
133137

134-
def _create_dataloader(self, dataset, shuffle, collate_fn=None):
138+
def _create_dataloader(self, dataset, shuffle, collate_fn=None, sampler=None):
135139
if dataset is None:
136140
return None
137141
if collate_fn is None:
138142
collate_fn = collate_dict
139143
return DataLoader(
140144
dataset=dataset,
141145
batch_size=self.batch_size,
142-
shuffle=shuffle,
146+
shuffle=shuffle if sampler is None else False,
147+
sampler=sampler,
143148
num_workers=self.num_workers,
144149
pin_memory=self.pin_memory,
145150
persistent_workers=(self.persistent_workers and self.num_workers > 0),
@@ -189,6 +194,45 @@ def __getitem__(self, index):
189194
return self.dataset[index % len(self.dataset)]
190195

191196

197+
def _is_distributed_evaluation_active() -> bool:
198+
return torch.distributed.is_available() and torch.distributed.is_initialized()
199+
200+
201+
class DistributedEvaluationSampler(Sampler[int]):
    """Shard evaluation samples across DDP ranks without padding or duplication.

    Rank ``r`` owns indices ``r, r + world_size, r + 2 * world_size, ...`` so
    every sample is visited exactly once across all ranks; ranks may therefore
    receive shard lengths that differ by one.
    """

    def __init__(
        self,
        dataset,
        *,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
    ):
        """Create a round-robin shard of ``dataset`` for one rank.

        Args:
            dataset: Sized dataset; only ``len(dataset)`` is consulted.
            rank: This process's rank. When either ``rank`` or ``world_size``
                is omitted, both are read from the active process group.
            world_size: Total number of ranks.

        Raises:
            RuntimeError: If rank/world_size are omitted and no distributed
                process group is initialized.
            ValueError: If ``world_size`` is not positive or ``rank`` is out
                of range.
        """
        if rank is None or world_size is None:
            # NOTE: when only one of the two is given, both are taken from the
            # process group, matching the original behavior.
            if not _is_distributed_evaluation_active():
                raise RuntimeError(
                    "DistributedEvaluationSampler requires an initialized distributed process "
                    "group or explicit rank/world_size."
                )
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

        if world_size <= 0:
            raise ValueError(f"world_size must be positive, got {world_size}.")
        if rank < 0 or rank >= world_size:
            raise ValueError(f"rank must satisfy 0 <= rank < world_size, got {rank}/{world_size}.")

        self.rank = int(rank)
        self.world_size = int(world_size)
        # Build the strided range directly instead of materializing the full
        # index list and slicing it (`list(range(n))[rank::world_size]`) —
        # same indices, without the O(n) temporary list.
        self.indices = list(range(self.rank, len(dataset), self.world_size))

    def __iter__(self):
        """Iterate this rank's dataset indices in ascending order."""
        return iter(self.indices)

    def __len__(self):
        """Number of samples assigned to this rank."""
        return len(self.indices)
234+
235+
192236
def collate_dict(
193237
batch: List[Dict[str, Any]],
194238
) -> Dict[str, Any]:
@@ -226,6 +270,7 @@ def collate_dict_list(
226270

227271
__all__ = [
228272
"ConnectomicsDataModule",
273+
"DistributedEvaluationSampler",
229274
"SimpleDataModule",
230275
"collate_dict",
231276
"collate_dict_list",

connectomics/training/lightning/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def create_trainer(
300300
detect_anomaly=detect_anomaly,
301301
enable_progress_bar=True,
302302
plugins=plugins,
303+
use_distributed_sampler=mode not in ("test", "tune-test"),
303304
)
304305

305306
_log.info(f" Training mode: {training_mode}")

scripts/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,15 @@ def maybe_limit_test_devices(cfg: Config, datamodule) -> bool:
399399

400400
tta_cfg = getattr(getattr(cfg, "inference", None), "test_time_augmentation", None)
401401
distributed_tta_sharding = bool(getattr(tta_cfg, "distributed_sharding", False))
402+
if distributed_tta_sharding and test_volume_count != 1:
403+
print(
404+
" WARNING: Disabling distributed TTA sharding for multi-volume test datasets. "
405+
"DDP ranks would otherwise reduce predictions from different volumes, which can "
406+
"mix samples or hang when shapes differ."
407+
)
408+
tta_cfg.distributed_sharding = False
409+
distributed_tta_sharding = False
410+
402411
if distributed_tta_sharding and test_volume_count == 1:
403412
safe_devices = max(1, min(requested_devices, _estimate_tta_total_passes(cfg)))
404413
if safe_devices < requested_devices:

tests/unit/test_connectomics_module.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Optional
22

3+
import numpy as np
34
import torch
45
import torch.nn as nn
56
import torch.nn.functional as F
@@ -173,3 +174,25 @@ def test_save_metrics_to_file_uses_runtime_inference_output_path(tmp_path):
173174
)
174175

175176
assert (tmp_path / "evaluation_metrics_vol0.txt").exists()
177+
178+
179+
def test_load_cached_predictions_reads_existing_prediction_files(tmp_path, monkeypatch):
    """Existing cached predictions should load without falling back to inference."""
    cfg = _base_config()
    lit_module = ConnectomicsModule(cfg, model=SimpleModel())

    # The cached file only needs to exist on disk; its content is never parsed
    # because read_volume is stubbed out below.
    (tmp_path / "sample_prediction.h5").write_text("stub")

    stub_volume = np.ones((1, 4, 4, 4), dtype=np.float32)
    monkeypatch.setattr("connectomics.training.lightning.model.read_volume", lambda *_args, **_kwargs: stub_volume)

    predictions, loaded, suffix = lit_module._load_cached_predictions(
        str(tmp_path),
        ["sample"],
        "_prediction.h5",
        "test",
    )

    assert loaded is True
    assert suffix == "_prediction.h5"
    assert predictions.shape == (1, 4, 4, 4)

tests/unit/test_inference_tta_masking.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,22 @@ def test_tta_channel_activations_follow_python_slice_semantics():
151151
expected[:, 0:2, ...] = torch.sigmoid(expected[:, 0:2, ...])
152152
expected[:, 2:3, ...] = torch.tanh(expected[:, 2:3, ...])
153153
assert torch.allclose(pred, expected)
154+
155+
156+
def test_distributed_tta_reduction_raises_on_mismatched_rank_shapes(monkeypatch):
    """Mismatched per-rank prediction shapes must abort with a clear error."""
    predictor = TTAPredictor(cfg=Config(), sliding_inferer=None, forward_fn=_forward_constant)

    # Pretend to be rank 0 of a 2-rank process group.
    monkeypatch.setattr(predictor, "_distributed_context", lambda: (True, 0, 2))

    # Encoded as [ndim, d0..d5]: rank 0 sees (1, 2, 4, 4, 4), rank 1 (1, 2, 6, 4, 4).
    encoded_rank_shapes = [
        [5, 1, 2, 4, 4, 4, 0],
        [5, 1, 2, 6, 4, 4, 0],
    ]

    def _stub_all_gather(output_tensors, _input_tensor):
        for slot, encoded in zip(output_tensors, encoded_rank_shapes):
            slot.copy_(torch.tensor(encoded, dtype=torch.int64))

    monkeypatch.setattr(torch.distributed, "all_gather", _stub_all_gather)

    with pytest.raises(RuntimeError, match="same shape"):
        predictor._validate_distributed_reduction_shape(
            torch.zeros((1, 2, 4, 4, 4), dtype=torch.float32),
            reduction_device=torch.device("cpu"),
        )

tests/unit/test_lightning_data_collate.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from connectomics.config import Config
66
from connectomics.inference.output import resolve_output_filenames
7-
from connectomics.training.lightning.data import ConnectomicsDataModule, collate_dict
7+
from connectomics.training.lightning.data import (
8+
ConnectomicsDataModule,
9+
DistributedEvaluationSampler,
10+
collate_dict,
11+
)
812

913

1014
def test_test_dataloader_preserves_variable_shape_tensors():
@@ -53,3 +57,17 @@ def test_resolve_output_filenames_supports_list_collated_images():
5357
}
5458

5559
assert resolve_output_filenames(cfg, batch, global_step=11) == ["input_a", "input_b"]
60+
61+
62+
def test_distributed_evaluation_sampler_partitions_without_duplicates():
    """All ranks together must cover the dataset exactly once, with no overlap."""
    samples = list(range(10))

    shards = [
        list(DistributedEvaluationSampler(samples, rank=r, world_size=4))
        for r in range(4)
    ]
    merged = [idx for shard in shards for idx in shard]

    assert sorted(merged) == list(range(10))
    assert len(set(merged)) == 10

tests/unit/test_main_runtime_stage_switch.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from connectomics.config import Config, save_config
55
from connectomics.config.schema.inference import EvaluationConfig
66
from connectomics.training.lightning.utils import setup_config
7-
from scripts.main import _is_test_evaluation_enabled, resolve_test_stage_runtime
7+
from scripts.main import (
8+
_is_test_evaluation_enabled,
9+
maybe_limit_test_devices,
10+
resolve_test_stage_runtime,
11+
)
812

913

1014
def _make_args(config_path: Path, mode: str = "test"):
@@ -69,3 +73,36 @@ def test_is_test_evaluation_enabled_supports_mapping_or_dataclass_config():
6973

7074
cfg.inference.evaluation.enabled = True
7175
assert _is_test_evaluation_enabled(cfg) is True
76+
77+
78+
class _DummyTestDataModule:
79+
def __init__(self, volume_count: int):
80+
self.test_data_dicts = [{} for _ in range(volume_count)]
81+
82+
83+
def test_maybe_limit_test_devices_disables_distributed_tta_sharding_for_multi_volume_tests():
    """With more than one test volume, sharding is disabled and devices capped."""
    cfg = Config()
    cfg.system.num_gpus = 4
    tta = cfg.inference.test_time_augmentation
    tta.enabled = True
    tta.distributed_sharding = True

    was_limited = maybe_limit_test_devices(cfg, _DummyTestDataModule(volume_count=2))

    assert was_limited is True
    assert cfg.system.num_gpus == 2
    assert cfg.inference.test_time_augmentation.distributed_sharding is False
94+
95+
96+
def test_maybe_limit_test_devices_keeps_distributed_tta_sharding_for_single_volume_tests():
    """A single test volume keeps sharding enabled and all requested GPUs."""
    cfg = Config()
    cfg.system.num_gpus = 4
    tta = cfg.inference.test_time_augmentation
    tta.enabled = True
    tta.distributed_sharding = True
    tta.flip_axes = [1, 2]
    tta.rotation90_axes = [[1, 2]]

    was_limited = maybe_limit_test_devices(cfg, _DummyTestDataModule(volume_count=1))

    assert was_limited is False
    assert cfg.system.num_gpus == 4
    assert cfg.inference.test_time_augmentation.distributed_sharding is True

tests/unit/test_trainer_logging.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,27 @@ def test_create_trainer_disables_logger_for_non_train_modes(tmp_path: Path, mode
1919

2020
assert trainer.logger is None
2121
assert not (tmp_path / "logs").exists()
22+
23+
24+
def test_create_trainer_disables_lightning_distributed_sampler_replacement_for_test(
    tmp_path: Path, monkeypatch
):
    """In test mode, create_trainer must pass use_distributed_sampler=False."""
    cfg = from_dict(
        {
            "system": {"num_gpus": 0},
            "optimization": {"max_epochs": 1},
        }
    )

    seen_kwargs = {}

    class _RecordingTrainer:
        """Captures every kwarg that create_trainer forwards to pl.Trainer."""

        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            self.logger = kwargs.get("logger")

    monkeypatch.setattr("connectomics.training.lightning.trainer.pl.Trainer", _RecordingTrainer)

    trainer = create_trainer(cfg, run_dir=tmp_path, mode="test")

    assert isinstance(trainer, _RecordingTrainer)
    assert seen_kwargs["use_distributed_sampler"] is False

0 commit comments

Comments
 (0)