Skip to content

Commit 9e2bd43

Browse files
author
Donglai Wei
committed
Avoid NCCL hangs after distributed TTA decode
1 parent 124317b commit 9e2bd43

2 files changed

Lines changed: 101 additions & 22 deletions

File tree

connectomics/training/lightning/test_pipeline.py

Lines changed: 32 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,10 @@ def _should_skip_postprocess_on_rank(module) -> bool:
153153
def _distributed_tta_barrier(module) -> None:
154154
if not _is_distributed_tta_sharding_active(module):
155155
return
156-
if torch.distributed.is_available() and torch.distributed.is_initialized():
157-
torch.distributed.barrier()
156+
# Rank 0 may spend a long time in CPU-side decoding/evaluation after the
157+
# distributed TTA reduction. Holding nonzero ranks in an NCCL barrier here
158+
# can trip the watchdog during long ABISS runs, so treat this as a no-op.
159+
return
158160

159161

160162
def _is_unstacked_test_batch(batch: Dict[str, Any]) -> bool:
@@ -580,7 +582,16 @@ def log_test_epoch_metrics(module) -> None:
580582
if not module._is_test_evaluation_enabled():
581583
return
582584

583-
if hasattr(module, "test_adapted_rand") and isinstance(module.test_adapted_rand, torchmetrics.Metric):
585+
is_dist = torch.distributed.is_available() and torch.distributed.is_initialized()
586+
rank = torch.distributed.get_rank() if is_dist else 0
587+
distributed_tta_sharding = _is_distributed_tta_sharding_active(module)
588+
if distributed_tta_sharding and rank != 0:
589+
return
590+
sync_dist = not distributed_tta_sharding
591+
592+
if hasattr(module, "test_adapted_rand") and isinstance(
593+
module.test_adapted_rand, torchmetrics.Metric
594+
):
584595
epoch_stats = module.test_adapted_rand.compute()
585596
if isinstance(epoch_stats, dict):
586597
module.log(
@@ -590,7 +601,7 @@ def log_test_epoch_metrics(module) -> None:
590601
on_epoch=True,
591602
prog_bar=True,
592603
logger=True,
593-
sync_dist=True,
604+
sync_dist=sync_dist,
594605
)
595606
module.log(
596607
"test_adapted_rand_precision",
@@ -599,7 +610,7 @@ def log_test_epoch_metrics(module) -> None:
599610
on_epoch=True,
600611
prog_bar=True,
601612
logger=True,
602-
sync_dist=True,
613+
sync_dist=sync_dist,
603614
)
604615
module.log(
605616
"test_adapted_rand_recall",
@@ -608,7 +619,7 @@ def log_test_epoch_metrics(module) -> None:
608619
on_epoch=True,
609620
prog_bar=True,
610621
logger=True,
611-
sync_dist=True,
622+
sync_dist=sync_dist,
612623
)
613624
else:
614625
module.log(
@@ -618,7 +629,7 @@ def log_test_epoch_metrics(module) -> None:
618629
on_epoch=True,
619630
prog_bar=True,
620631
logger=True,
621-
sync_dist=True,
632+
sync_dist=sync_dist,
622633
)
623634

624635
if hasattr(module, "test_voi") and isinstance(module.test_voi, torchmetrics.Metric):
@@ -629,7 +640,7 @@ def log_test_epoch_metrics(module) -> None:
629640
on_epoch=True,
630641
prog_bar=True,
631642
logger=True,
632-
sync_dist=True,
643+
sync_dist=sync_dist,
633644
)
634645
module.log(
635646
"test_voi_split",
@@ -638,7 +649,7 @@ def log_test_epoch_metrics(module) -> None:
638649
on_epoch=True,
639650
prog_bar=False,
640651
logger=True,
641-
sync_dist=True,
652+
sync_dist=sync_dist,
642653
)
643654
module.log(
644655
"test_voi_merge",
@@ -647,7 +658,7 @@ def log_test_epoch_metrics(module) -> None:
647658
on_epoch=True,
648659
prog_bar=False,
649660
logger=True,
650-
sync_dist=True,
661+
sync_dist=sync_dist,
651662
)
652663

653664
if hasattr(module, "test_instance_accuracy") and isinstance(
@@ -660,7 +671,7 @@ def log_test_epoch_metrics(module) -> None:
660671
on_epoch=True,
661672
prog_bar=True,
662673
logger=True,
663-
sync_dist=True,
674+
sync_dist=sync_dist,
664675
)
665676

666677
if hasattr(module, "test_instance_accuracy_detail") and isinstance(
@@ -673,7 +684,7 @@ def log_test_epoch_metrics(module) -> None:
673684
on_epoch=True,
674685
prog_bar=True,
675686
logger=True,
676-
sync_dist=True,
687+
sync_dist=sync_dist,
677688
)
678689
module.log(
679690
"test_instance_precision_detail",
@@ -682,7 +693,7 @@ def log_test_epoch_metrics(module) -> None:
682693
on_epoch=True,
683694
prog_bar=False,
684695
logger=True,
685-
sync_dist=True,
696+
sync_dist=sync_dist,
686697
)
687698
module.log(
688699
"test_instance_recall_detail",
@@ -691,7 +702,7 @@ def log_test_epoch_metrics(module) -> None:
691702
on_epoch=True,
692703
prog_bar=False,
693704
logger=True,
694-
sync_dist=True,
705+
sync_dist=sync_dist,
695706
)
696707
module.log(
697708
"test_instance_f1_detail",
@@ -700,7 +711,7 @@ def log_test_epoch_metrics(module) -> None:
700711
on_epoch=True,
701712
prog_bar=False,
702713
logger=True,
703-
sync_dist=True,
714+
sync_dist=sync_dist,
704715
)
705716

706717
if hasattr(module, "test_jaccard") and module.test_jaccard is not None:
@@ -711,7 +722,7 @@ def log_test_epoch_metrics(module) -> None:
711722
on_epoch=True,
712723
prog_bar=True,
713724
logger=True,
714-
sync_dist=True,
725+
sync_dist=sync_dist,
715726
)
716727

717728
if hasattr(module, "test_dice") and module.test_dice is not None:
@@ -722,7 +733,7 @@ def log_test_epoch_metrics(module) -> None:
722733
on_epoch=True,
723734
prog_bar=True,
724735
logger=True,
725-
sync_dist=True,
736+
sync_dist=sync_dist,
726737
)
727738

728739
if hasattr(module, "test_accuracy") and module.test_accuracy is not None:
@@ -733,7 +744,7 @@ def log_test_epoch_metrics(module) -> None:
733744
on_epoch=True,
734745
prog_bar=True,
735746
logger=True,
736-
sync_dist=True,
747+
sync_dist=sync_dist,
737748
)
738749

739750

@@ -858,7 +869,9 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
858869

859870
if lazy_sample:
860871
image_path = str(batch["image"])
861-
mask_path = str(batch["mask"]) if isinstance(batch.get("mask"), (str, os.PathLike)) else None
872+
mask_path = (
873+
str(batch["mask"]) if isinstance(batch.get("mask"), (str, os.PathLike)) else None
874+
)
862875
labels = None
863876
reference_image_shape = get_lazy_image_reference_shape(module.cfg, image_path, mode=mode)
864877
crop_pad = _resolve_postprocessing_crop_pad(module)

tests/unit/test_connectomics_module.py

Lines changed: 69 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,9 @@ def test_load_cached_predictions_reads_existing_prediction_files(tmp_path, monke
210210
pred_file.write_text("stub")
211211

212212
expected = np.ones((1, 4, 4, 4), dtype=np.float32)
213-
monkeypatch.setattr("connectomics.training.lightning.model.read_volume", lambda *_args, **_kwargs: expected)
213+
monkeypatch.setattr(
214+
"connectomics.training.lightning.model.read_volume", lambda *_args, **_kwargs: expected
215+
)
214216

215217
predictions, loaded, suffix = module._load_cached_predictions(
216218
str(tmp_path),
@@ -224,7 +226,7 @@ def test_load_cached_predictions_reads_existing_prediction_files(tmp_path, monke
224226
assert predictions.shape == (1, 4, 4, 4)
225227

226228

227-
def test_on_test_end_logs_aggregated_metrics_once():
229+
def test_on_test_epoch_end_logs_aggregated_metrics_once():
228230
cfg = _base_config()
229231
cfg.inference.evaluation.enabled = True
230232
cfg.inference.evaluation.metrics = ["accuracy"]
@@ -235,6 +237,70 @@ def test_on_test_end_logs_aggregated_metrics_once():
235237
module.test_accuracy = torchmetrics.Accuracy(task="binary")
236238
module.test_accuracy.update(torch.tensor([1, 0]), torch.tensor([1, 0]))
237239

238-
module.on_test_end()
240+
module.on_test_epoch_end()
239241

240242
assert logged_names == ["test_accuracy"]
243+
244+
245+
def test_log_test_epoch_metrics_uses_rank_zero_only_logging_for_distributed_tta_sharding(
246+
monkeypatch,
247+
):
248+
cfg = _base_config()
249+
cfg.inference.evaluation.enabled = True
250+
cfg.inference.evaluation.metrics = ["accuracy"]
251+
cfg.inference.test_time_augmentation.enabled = True
252+
cfg.inference.test_time_augmentation.distributed_sharding = True
253+
254+
module = ConnectomicsModule(cfg, model=SimpleModel())
255+
calls: list[tuple[str, bool]] = []
256+
257+
def log_override(name, *_args, **kwargs):
258+
calls.append((name, kwargs["sync_dist"]))
259+
return None
260+
261+
module.log = log_override
262+
module.test_accuracy = torchmetrics.Accuracy(task="binary")
263+
module.test_accuracy.update(torch.tensor([1, 0]), torch.tensor([1, 0]))
264+
265+
monkeypatch.setattr(
266+
"connectomics.training.lightning.test_pipeline._is_distributed_tta_sharding_active",
267+
lambda _module: True,
268+
)
269+
monkeypatch.setattr(torch.distributed, "is_available", lambda: True)
270+
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
271+
monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)
272+
273+
log_test_epoch_metrics(module)
274+
275+
assert calls == [("test_accuracy", False)]
276+
277+
278+
def test_log_test_epoch_metrics_skips_nonzero_ranks_for_distributed_tta_sharding(monkeypatch):
279+
cfg = _base_config()
280+
cfg.inference.evaluation.enabled = True
281+
cfg.inference.evaluation.metrics = ["accuracy"]
282+
cfg.inference.test_time_augmentation.enabled = True
283+
cfg.inference.test_time_augmentation.distributed_sharding = True
284+
285+
module = ConnectomicsModule(cfg, model=SimpleModel())
286+
calls: list[str] = []
287+
288+
def log_override(name, *_args, **_kwargs):
289+
calls.append(name)
290+
return None
291+
292+
module.log = log_override
293+
module.test_accuracy = torchmetrics.Accuracy(task="binary")
294+
module.test_accuracy.update(torch.tensor([1, 0]), torch.tensor([1, 0]))
295+
296+
monkeypatch.setattr(
297+
"connectomics.training.lightning.test_pipeline._is_distributed_tta_sharding_active",
298+
lambda _module: True,
299+
)
300+
monkeypatch.setattr(torch.distributed, "is_available", lambda: True)
301+
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
302+
monkeypatch.setattr(torch.distributed, "get_rank", lambda: 1)
303+
304+
log_test_epoch_metrics(module)
305+
306+
assert calls == []

0 commit comments

Comments
 (0)