@@ -437,6 +437,99 @@ def maybe_limit_test_devices(cfg: Config, datamodule) -> bool:
     return True


+def resolve_test_rank_shard_from_env() -> tuple[int | None, int | None]:
+    """Return rank/world_size for externally launched multi-process test jobs."""
+    for rank_key, world_key in (("RANK", "WORLD_SIZE"), ("SLURM_PROCID", "SLURM_NTASKS")):
+        rank_raw = os.environ.get(rank_key)
+        world_raw = os.environ.get(world_key)
+        if rank_raw is None or world_raw is None:
+            continue
+        try:
+            rank = int(rank_raw)
+            world_size = int(world_raw)
+        except ValueError:
+            continue
+        if world_size > 1:
+            return rank, world_size
+
+    return None, None
+
+
+def resolve_test_image_paths(cfg: Config) -> list[str]:
+    """Resolve test image paths from config for shard planning."""
+    data_cfg = getattr(cfg, "data", None)
+    test_image = getattr(getattr(data_cfg, "test", None), "image", None)
+    if not test_image:
+        return []
+
+    from connectomics.training.lightning.path_utils import expand_file_paths
+
+    try:
+        return expand_file_paths(test_image)
+    except Exception as exc:
+        print(f" WARNING: Failed to resolve test_image paths for sharding: {exc}")
+        return []
+
+
+def maybe_enable_independent_test_sharding(args, cfg: Config) -> bool:
+    """Run test as independent single-GPU shards instead of DDP when rank info is available."""
+    requested_devices = int(getattr(cfg.system, "num_gpus", 0) or 0)
+    if requested_devices <= 1:
+        return False
+
+    shard_id = getattr(args, "shard_id", None)
+    num_shards = getattr(args, "num_shards", None)
+    source = None
+
+    if shard_id is not None and num_shards is not None and int(num_shards) > 1:
+        source = "explicit shard arguments"
+    else:
+        test_image_paths = resolve_test_image_paths(cfg)
+        if len(test_image_paths) <= 1:
+            return False
+
+        shard_id, num_shards = resolve_test_rank_shard_from_env()
+        if shard_id is None or num_shards is None:
+            return False
+
+        args.shard_id = shard_id
+        args.num_shards = num_shards
+        source = "distributed launcher environment"
+
+    tta_cfg = getattr(getattr(cfg, "inference", None), "test_time_augmentation", None)
+    if tta_cfg is not None and bool(getattr(tta_cfg, "distributed_sharding", False)):
+        print(
+            " WARNING: Disabling distributed TTA sharding for independent per-rank test sharding."
+        )
+        tta_cfg.distributed_sharding = False
+
+    cfg.system.num_gpus = 1 if torch.cuda.is_available() else 0
+    print(
+        " INFO: Independent multi-GPU test sharding enabled "
+        f"({source}); each process will handle its own shard with no DDP communication."
+    )
+    return True
+
+
+def has_assigned_test_shard(cfg: Config, args) -> bool:
+    """Return True if the current shard has at least one test volume to process."""
+    shard_id = getattr(args, "shard_id", None)
+    num_shards = getattr(args, "num_shards", None)
+    if shard_id is None or num_shards is None:
+        return True
+
+    test_image_paths = resolve_test_image_paths(cfg)
+    if not test_image_paths:
+        return True
+
+    if test_image_paths[shard_id::num_shards]:
+        return True
+
+    print(f" Shard {shard_id}/{num_shards} is empty, nothing to do.")
+    print("[OK] Test completed successfully (empty shard).")
+    return False
+
+
 def shard_test_datamodule(datamodule, shard_id: int, num_shards: int):
     """Shard test volumes across machines.

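Note on the mechanics: the helpers above implement a rank-strided split, where shard i of n takes every n-th test volume starting at index i. A minimal sketch of how they compose under a torchrun launch (the environment values and volume names are illustrative, not taken from a real job):

    # Illustrative sketch: simulate the env that `torchrun --nproc_per_node=4`
    # would export for rank 1; the volume names are hypothetical.
    import os

    os.environ["RANK"] = "1"
    os.environ["WORLD_SIZE"] = "4"

    rank, world_size = resolve_test_rank_shard_from_env()
    assert (rank, world_size) == (1, 4)

    # Rank-strided assignment: shard i of n handles volumes[i::n].
    volumes = ["vol_a.h5", "vol_b.h5", "vol_c.h5", "vol_d.h5", "vol_e.h5"]
    print(volumes[rank::world_size])  # ['vol_b.h5']
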
@@ -766,6 +859,11 @@ def main():
     print(f"Random seed set to: {cfg.system.seed}")
     seed_everything(cfg.system.seed, workers=True)

+    if args.mode == "test":
+        maybe_enable_independent_test_sharding(args, cfg)
+        if not has_assigned_test_shard(cfg, args):
+            return
+
     # Cache-only preflight path for test mode (can skip model/trainer/dataloader entirely).
     if try_cache_only_test_execution(cfg, args.mode, args.shard_id, args.num_shards):
         return
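With the early return in place, over-provisioned launches degrade gracefully: a rank whose strided slice is empty prints the empty-shard notice and exits before any model, trainer, or dataloader is built. A quick worked example of the slice arithmetic used by has_assigned_test_shard (hypothetical paths, two volumes across four ranks):

    paths = ["vol_a.h5", "vol_b.h5"]  # hypothetical test volumes
    for shard_id in range(4):
        print(shard_id, paths[shard_id::4])
    # 0 ['vol_a.h5']
    # 1 ['vol_b.h5']
    # 2 []   (ranks 2 and 3 report an empty shard and return early)
    # 3 []
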
@@ -839,6 +937,17 @@ def main():
     # Re-resolve test-stage runtime overrides after tuning, including sentinels.
     cfg = resolve_test_stage_runtime(cfg)

+    if maybe_enable_independent_test_sharding(args, cfg):
+        trainer = create_trainer(
+            cfg,
+            run_dir=run_dir,
+            fast_dev_run=args.fast_dev_run,
+            ckpt_path=ckpt_path,
+            mode="test",
+        )
+        if not has_assigned_test_shard(cfg, args):
+            return
+
     # Create datamodule
     datamodule = create_datamodule(cfg, mode="test")
@@ -882,7 +991,7 @@ def main():

     trainer.test(
         model,
-        datamodule=datamodule,
+        datamodule,
         ckpt_path=test_ckpt_path,
     )

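The final hunk drops the keyword and passes the datamodule positionally. In PyTorch Lightning 2.x the second positional parameter of Trainer.test is dataloaders, which also accepts a LightningDataModule, so the two spellings are interchangeable; a sketch of the equivalence (not code from this repo):

    # Equivalent calls in Lightning 2.x: `dataloaders` also accepts a LightningDataModule.
    trainer.test(model, datamodule=datamodule, ckpt_path=test_ckpt_path)
    trainer.test(model, datamodule, ckpt_path=test_ckpt_path)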