@@ -153,8 +153,10 @@ def _should_skip_postprocess_on_rank(module) -> bool:
153153def _distributed_tta_barrier (module ) -> None :
154154 if not _is_distributed_tta_sharding_active (module ):
155155 return
156- if torch .distributed .is_available () and torch .distributed .is_initialized ():
157- torch .distributed .barrier ()
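156+ # Intentionally a no-op: rank 0 may spend a long time in CPU-side decoding/evaluation
157+ # after the distributed TTA reduction, and holding nonzero ranks in an NCCL barrier here can
158+ # trip the watchdog during long ABISS runs. NOTE(review): both paths now return without a barrier, so the sharding check above looks redundant - confirm before removing.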
159+ return
158160
159161
160162def _is_unstacked_test_batch (batch : Dict [str , Any ]) -> bool :
@@ -580,7 +582,16 @@ def log_test_epoch_metrics(module) -> None:
580582 if not module ._is_test_evaluation_enabled ():
581583 return
582584
583- if hasattr (module , "test_adapted_rand" ) and isinstance (module .test_adapted_rand , torchmetrics .Metric ):
585+ is_dist = torch .distributed .is_available () and torch .distributed .is_initialized ()
586+ rank = torch .distributed .get_rank () if is_dist else 0
587+ distributed_tta_sharding = _is_distributed_tta_sharding_active (module )
588+ if distributed_tta_sharding and rank != 0 :
589+ return
590+ sync_dist = not distributed_tta_sharding
591+
592+ if hasattr (module , "test_adapted_rand" ) and isinstance (
593+ module .test_adapted_rand , torchmetrics .Metric
594+ ):
584595 epoch_stats = module .test_adapted_rand .compute ()
585596 if isinstance (epoch_stats , dict ):
586597 module .log (
@@ -590,7 +601,7 @@ def log_test_epoch_metrics(module) -> None:
590601 on_epoch = True ,
591602 prog_bar = True ,
592603 logger = True ,
593- sync_dist = True ,
604+ sync_dist = sync_dist ,
594605 )
595606 module .log (
596607 "test_adapted_rand_precision" ,
@@ -599,7 +610,7 @@ def log_test_epoch_metrics(module) -> None:
599610 on_epoch = True ,
600611 prog_bar = True ,
601612 logger = True ,
602- sync_dist = True ,
613+ sync_dist = sync_dist ,
603614 )
604615 module .log (
605616 "test_adapted_rand_recall" ,
@@ -608,7 +619,7 @@ def log_test_epoch_metrics(module) -> None:
608619 on_epoch = True ,
609620 prog_bar = True ,
610621 logger = True ,
611- sync_dist = True ,
622+ sync_dist = sync_dist ,
612623 )
613624 else :
614625 module .log (
@@ -618,7 +629,7 @@ def log_test_epoch_metrics(module) -> None:
618629 on_epoch = True ,
619630 prog_bar = True ,
620631 logger = True ,
621- sync_dist = True ,
632+ sync_dist = sync_dist ,
622633 )
623634
624635 if hasattr (module , "test_voi" ) and isinstance (module .test_voi , torchmetrics .Metric ):
@@ -629,7 +640,7 @@ def log_test_epoch_metrics(module) -> None:
629640 on_epoch = True ,
630641 prog_bar = True ,
631642 logger = True ,
632- sync_dist = True ,
643+ sync_dist = sync_dist ,
633644 )
634645 module .log (
635646 "test_voi_split" ,
@@ -638,7 +649,7 @@ def log_test_epoch_metrics(module) -> None:
638649 on_epoch = True ,
639650 prog_bar = False ,
640651 logger = True ,
641- sync_dist = True ,
652+ sync_dist = sync_dist ,
642653 )
643654 module .log (
644655 "test_voi_merge" ,
@@ -647,7 +658,7 @@ def log_test_epoch_metrics(module) -> None:
647658 on_epoch = True ,
648659 prog_bar = False ,
649660 logger = True ,
650- sync_dist = True ,
661+ sync_dist = sync_dist ,
651662 )
652663
653664 if hasattr (module , "test_instance_accuracy" ) and isinstance (
@@ -660,7 +671,7 @@ def log_test_epoch_metrics(module) -> None:
660671 on_epoch = True ,
661672 prog_bar = True ,
662673 logger = True ,
663- sync_dist = True ,
674+ sync_dist = sync_dist ,
664675 )
665676
666677 if hasattr (module , "test_instance_accuracy_detail" ) and isinstance (
@@ -673,7 +684,7 @@ def log_test_epoch_metrics(module) -> None:
673684 on_epoch = True ,
674685 prog_bar = True ,
675686 logger = True ,
676- sync_dist = True ,
687+ sync_dist = sync_dist ,
677688 )
678689 module .log (
679690 "test_instance_precision_detail" ,
@@ -682,7 +693,7 @@ def log_test_epoch_metrics(module) -> None:
682693 on_epoch = True ,
683694 prog_bar = False ,
684695 logger = True ,
685- sync_dist = True ,
696+ sync_dist = sync_dist ,
686697 )
687698 module .log (
688699 "test_instance_recall_detail" ,
@@ -691,7 +702,7 @@ def log_test_epoch_metrics(module) -> None:
691702 on_epoch = True ,
692703 prog_bar = False ,
693704 logger = True ,
694- sync_dist = True ,
705+ sync_dist = sync_dist ,
695706 )
696707 module .log (
697708 "test_instance_f1_detail" ,
@@ -700,7 +711,7 @@ def log_test_epoch_metrics(module) -> None:
700711 on_epoch = True ,
701712 prog_bar = False ,
702713 logger = True ,
703- sync_dist = True ,
714+ sync_dist = sync_dist ,
704715 )
705716
706717 if hasattr (module , "test_jaccard" ) and module .test_jaccard is not None :
@@ -711,7 +722,7 @@ def log_test_epoch_metrics(module) -> None:
711722 on_epoch = True ,
712723 prog_bar = True ,
713724 logger = True ,
714- sync_dist = True ,
725+ sync_dist = sync_dist ,
715726 )
716727
717728 if hasattr (module , "test_dice" ) and module .test_dice is not None :
@@ -722,7 +733,7 @@ def log_test_epoch_metrics(module) -> None:
722733 on_epoch = True ,
723734 prog_bar = True ,
724735 logger = True ,
725- sync_dist = True ,
736+ sync_dist = sync_dist ,
726737 )
727738
728739 if hasattr (module , "test_accuracy" ) and module .test_accuracy is not None :
@@ -733,7 +744,7 @@ def log_test_epoch_metrics(module) -> None:
733744 on_epoch = True ,
734745 prog_bar = True ,
735746 logger = True ,
736- sync_dist = True ,
747+ sync_dist = sync_dist ,
737748 )
738749
739750
@@ -858,7 +869,9 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
858869
859870 if lazy_sample :
860871 image_path = str (batch ["image" ])
861- mask_path = str (batch ["mask" ]) if isinstance (batch .get ("mask" ), (str , os .PathLike )) else None
872+ mask_path = (
873+ str (batch ["mask" ]) if isinstance (batch .get ("mask" ), (str , os .PathLike )) else None
874+ )
862875 labels = None
863876 reference_image_shape = get_lazy_image_reference_shape (module .cfg , image_path , mode = mode )
864877 crop_pad = _resolve_postprocessing_crop_pad (module )
0 commit comments