@@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self):
445445 assert sampler ._states == {}
446446
447447 def test_rank_aware_sampler_first_rank_sampling (self ):
448- """Test that first rank in DP group performs actual sampling."""
448+ """Test that first rank in data replica group performs actual sampling."""
449449 sampler = RankAwareSampler ()
450450 ready_indexes = [0 , 1 , 2 , 3 , 4 , 5 ]
451451 batch_size = 3
452452
453- # When world_size == dp_world_size, fetches_per_batch = 1
454- # First rank samples and immediately marks consumed (no other ranks to wait for)
455- sampled , consumed = sampler .sample (ready_indexes , batch_size , dp_group = 0 , dp_world_size = 2 , world_size = 2 )
453+ # Rank 0 (first in group) samples and caches for all ranks
454+ # Since rank 1 will call next, state is kept until rank 1 fetches
455+ sampled , consumed = sampler .sample (
456+ ready_indexes ,
457+ batch_size ,
458+ data_replica_group = 0 ,
459+ data_replica_rank = 0 ,
460+ data_replica_world_size = 2 ,
461+ task_name = "task" ,
462+ partition_id = "test" ,
463+ )
456464
457465 assert sampled == [0 , 1 , 2 ]
458- # consumed is returned
459466 assert consumed == [0 , 1 , 2 ]
460467 assert len (sampled ) == batch_size
461- # State should be cleaned up
462- assert sampler ._states == {}
468+ # State is kept for other ranks to fetch
463469
464470 def test_rank_aware_sampler_second_rank_gets_cached (self ):
465- """Test that second rank in DP group gets cached indices."""
471+ """Test that second rank in data replica group gets cached indices."""
466472 sampler = RankAwareSampler ()
467473 ready_indexes = [0 , 1 , 2 , 3 , 4 , 5 ]
468474 batch_size = 3
469- dp_world_size = 2
470- world_size = 4 # Use world_size=4 so fetches_per_batch=2
471475
472- # Rank 0 (dp_group=0 ) samples first
476+ # Rank 0 (first in group ) samples first
473477 sampled1 , consumed1 = sampler .sample (
474- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
478+ ready_indexes ,
479+ batch_size ,
480+ data_replica_group = 0 ,
481+ data_replica_rank = 0 ,
482+ data_replica_world_size = 2 ,
483+ task_name = "task" ,
484+ partition_id = "test" ,
475485 )
476486
477- # Rank 1 (dp_group=0 ) should get same cached indices
487+ # Rank 1 (second in group ) should get same cached indices
478488 sampled2 , consumed2 = sampler .sample (
479- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
489+ ready_indexes ,
490+ batch_size ,
491+ data_replica_group = 0 ,
492+ data_replica_rank = 1 ,
493+ data_replica_world_size = 2 ,
494+ task_name = "task" ,
495+ partition_id = "test" ,
480496 )
481497
482498 assert sampled1 == sampled2 == [0 , 1 , 2 ]
483- # First rank already returns consumed indexes
484499 assert consumed1 == [0 , 1 , 2 ]
485- # Second rank also sees the same consumed indexes; state is then cleaned up
486500 assert consumed2 == [0 , 1 , 2 ]
487- # State should be cleaned up
488- assert sampler ._states == {}
501+
502+ # cache should be empty after all ranks fetch
503+ assert len (sampler ._states ["test" ]["task" ][0 ][0 ]) == 0
504+ assert len (sampler ._states ["test" ]["task" ][0 ][1 ]) == 0
489505
490506 def test_rank_aware_sampler_multiple_dp_groups (self ):
491- """Test that multiple DP groups work independently."""
507+ """Test that multiple data replica groups work independently."""
492508 sampler = RankAwareSampler ()
493509 ready_indexes = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]
494510 batch_size = 2
495- dp_world_size = 4
496- world_size = 8
511+ data_replica_world_size = 2 # Each group has 2 ranks
497512
498- # DP group 0: rank 0 samples first
513+ # data replica group 0: rank 0 samples first
499514 sampled0_g0 , consumed0_g0 = sampler .sample (
500- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
515+ ready_indexes ,
516+ batch_size ,
517+ data_replica_group = 0 ,
518+ data_replica_rank = 0 ,
519+ data_replica_world_size = data_replica_world_size ,
520+ task_name = "task" ,
521+ partition_id = "test" ,
501522 )
502523 # mimic the consumption status update managed in TransferQueueController
503524 ready_indexes = [i for i in ready_indexes if i not in consumed0_g0 ]
504525
505- # DP group 1: rank 0 samples first
526+ # data replica group 1: rank 0 samples first
506527 sampled0_g1 , consumed0_g1 = sampler .sample (
507- ready_indexes , batch_size , dp_group = 1 , dp_world_size = dp_world_size , world_size = world_size
528+ ready_indexes ,
529+ batch_size ,
530+ data_replica_group = 1 ,
531+ data_replica_rank = 0 ,
532+ data_replica_world_size = data_replica_world_size ,
533+ task_name = "task" ,
534+ partition_id = "test" ,
508535 )
509536 ready_indexes = [i for i in ready_indexes if i not in consumed0_g1 ]
510537
@@ -514,47 +541,82 @@ def test_rank_aware_sampler_multiple_dp_groups(self):
514541 assert consumed0_g0 == [0 , 1 ]
515542 assert consumed0_g1 == [2 , 3 ]
516543
517- # DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
544+ # data replica group 0: rank 1 fetches cached
518545 sampled1_g0 , consumed1_g0 = sampler .sample (
519- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
546+ ready_indexes ,
547+ batch_size ,
548+ data_replica_group = 0 ,
549+ data_replica_rank = 1 ,
550+ data_replica_world_size = data_replica_world_size ,
551+ task_name = "task" ,
552+ partition_id = "test" ,
520553 )
521554 ready_indexes = [i for i in ready_indexes if i not in consumed1_g0 ]
522555 assert sampled1_g0 == [0 , 1 ]
523556 assert consumed1_g0 == [0 , 1 ]
524557
525- # DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
558+ # data replica group 1: rank 1 fetches cached
526559 sampled1_g1 , consumed1_g1 = sampler .sample (
527- ready_indexes , batch_size , dp_group = 1 , dp_world_size = dp_world_size , world_size = world_size
560+ ready_indexes ,
561+ batch_size ,
562+ data_replica_group = 1 ,
563+ data_replica_rank = 1 ,
564+ data_replica_world_size = data_replica_world_size ,
565+ task_name = "task" ,
566+ partition_id = "test" ,
528567 )
529568 ready_indexes = [i for i in ready_indexes if i not in consumed1_g1 ]
530569 assert sampled1_g1 == [2 , 3 ]
531570 assert consumed1_g1 == [2 , 3 ]
532571
533- # DP group 0: rank 0 fetches again, this should return new data
572+ # data replica group 0: rank 0 fetches again, this should return new data
534573 sampled2_g0 , consumed2_g0 = sampler .sample (
535- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
574+ ready_indexes ,
575+ batch_size ,
576+ data_replica_group = 0 ,
577+ data_replica_rank = 0 ,
578+ data_replica_world_size = data_replica_world_size ,
579+ task_name = "task" ,
580+ partition_id = "test" ,
536581 )
537582 ready_indexes = [i for i in ready_indexes if i not in consumed2_g0 ]
538583 assert sampled2_g0 == [4 , 5 ]
539584 assert consumed2_g0 == [4 , 5 ]
540585
541- # DP group 0: rank 1 fetches cached
586+ # data replica group 0: rank 1 fetches cached
542587 sampled3_g0 , consumed3_g0 = sampler .sample (
543- ready_indexes , batch_size , dp_group = 0 , dp_world_size = dp_world_size , world_size = world_size
588+ ready_indexes ,
589+ batch_size ,
590+ data_replica_group = 0 ,
591+ data_replica_rank = 1 ,
592+ data_replica_world_size = data_replica_world_size ,
593+ task_name = "task" ,
594+ partition_id = "test" ,
544595 )
545596 assert sampled3_g0 == [4 , 5 ]
546597 assert consumed3_g0 == [4 , 5 ]
547598
548- # Both groups should be cleaned up
549- assert sampler ._states == {}
599+ # examine the internal state to ensure proper caching and clearing
600+ assert len (sampler ._states ["test" ]["task" ][0 ][0 ]) == 0
601+ assert len (sampler ._states ["test" ]["task" ][0 ][1 ]) == 0
602+ assert len (sampler ._states ["test" ]["task" ][1 ][0 ]) == 0
603+ assert len (sampler ._states ["test" ]["task" ][1 ][1 ]) == 0
550604
551605 def test_rank_aware_sampler_empty_ready_indexes (self ):
552606 """Test behavior with empty ready indexes."""
553607 sampler = RankAwareSampler ()
554608 ready_indexes = []
555609 batch_size = 3
556610
557- sampled , consumed = sampler .sample (ready_indexes , batch_size , dp_group = 0 , dp_world_size = 2 , world_size = 2 )
611+ sampled , consumed = sampler .sample (
612+ ready_indexes ,
613+ batch_size ,
614+ data_replica_group = 0 ,
615+ data_replica_rank = 0 ,
616+ data_replica_world_size = 2 ,
617+ task_name = "task" ,
618+ partition_id = "test" ,
619+ )
558620
559621 assert sampled == []
560622 assert consumed == []
@@ -565,8 +627,15 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self):
565627 ready_indexes = [0 , 1 ]
566628 batch_size = 5
567629
568- # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately
569- sampled , consumed = sampler .sample (ready_indexes , batch_size , dp_group = 0 , dp_world_size = 2 , world_size = 2 )
630+ sampled , consumed = sampler .sample (
631+ ready_indexes ,
632+ batch_size ,
633+ data_replica_group = 0 ,
634+ data_replica_rank = 0 ,
635+ data_replica_world_size = 2 ,
636+ task_name = "task" ,
637+ partition_id = "test" ,
638+ )
570639
571640 assert sampled == []
572641 assert consumed == []
@@ -577,11 +646,112 @@ def test_rank_aware_sampler_zero_batch_size(self):
577646 ready_indexes = [0 , 1 , 2 , 3 ]
578647 batch_size = 0
579648
580- sampled , consumed = sampler .sample (ready_indexes , batch_size , dp_group = 0 , dp_world_size = 2 , world_size = 2 )
649+ sampled , consumed = sampler .sample (
650+ ready_indexes ,
651+ batch_size ,
652+ data_replica_group = 0 ,
653+ data_replica_rank = 0 ,
654+ data_replica_world_size = 2 ,
655+ task_name = "task" ,
656+ partition_id = "test" ,
657+ )
581658
582659 assert sampled == []
583660 assert consumed == []
584661
662+ def test_rank_aware_sampler_data_prefetch (self ):
663+ """Test behavior with data prefetch."""
664+ sampler = RankAwareSampler ()
665+ ready_indexes = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]
666+ batch_size = 2
667+
668+ sampled_rank0_time0 , consumed_rank0_time0 = sampler .sample (
669+ ready_indexes ,
670+ batch_size ,
671+ data_replica_group = 0 ,
672+ data_replica_rank = 0 ,
673+ data_replica_world_size = 2 ,
674+ task_name = "task" ,
675+ partition_id = "test" ,
676+ )
677+
678+ assert sampled_rank0_time0 == [0 , 1 ]
679+ assert consumed_rank0_time0 == [0 , 1 ]
680+ assert sampler ._states ["test" ]["task" ][0 ][0 ] == []
681+ assert sampler ._states ["test" ]["task" ][0 ][1 ] == [[0 , 1 ]]
682+
683+ ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0 ]
684+
685+ sampled_rank0_time1 , consumed_rank0_time1 = sampler .sample (
686+ ready_indexes ,
687+ batch_size ,
688+ data_replica_group = 0 ,
689+ data_replica_rank = 0 ,
690+ data_replica_world_size = 2 ,
691+ task_name = "task" ,
692+ partition_id = "test" ,
693+ )
694+
695+ assert sampled_rank0_time1 == [2 , 3 ]
696+ assert consumed_rank0_time1 == [2 , 3 ]
697+ assert sampler ._states ["test" ]["task" ][0 ][0 ] == []
698+ assert sampler ._states ["test" ]["task" ][0 ][1 ] == [[0 , 1 ], [2 , 3 ]]
699+
700+ ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1 ]
701+
702+ sampled_rank1_time0 , consumed_rank1_time0 = sampler .sample (
703+ ready_indexes ,
704+ batch_size ,
705+ data_replica_group = 0 ,
706+ data_replica_rank = 1 ,
707+ data_replica_world_size = 2 ,
708+ task_name = "task" ,
709+ partition_id = "test" ,
710+ )
711+ assert sampled_rank1_time0 == [0 , 1 ]
712+ assert consumed_rank1_time0 == [0 , 1 ]
713+
714+ assert sampler ._states ["test" ]["task" ][0 ][0 ] == []
715+ assert sampler ._states ["test" ]["task" ][0 ][1 ] == [[2 , 3 ]]
716+
717+ def test_rank_aware_sampler_multiple_tasks (self ):
718+ """Test behavior with multiple tasks."""
719+ sampler = RankAwareSampler ()
720+ ready_indexes = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]
721+ batch_size = 2
722+
723+ sampled_rank0_task0 , consumed_rank0_task0 = sampler .sample (
724+ ready_indexes ,
725+ batch_size ,
726+ data_replica_group = 0 ,
727+ data_replica_rank = 0 ,
728+ data_replica_world_size = 2 ,
729+ task_name = "task0" ,
730+ partition_id = "test" ,
731+ )
732+
733+ assert sampled_rank0_task0 == [0 , 1 ]
734+ assert consumed_rank0_task0 == [0 , 1 ]
735+ assert sampler ._states ["test" ]["task0" ][0 ][0 ] == []
736+ assert sampler ._states ["test" ]["task0" ][0 ][1 ] == [[0 , 1 ]]
737+
738+ sampled_rank0_task1 , consumed_rank0_task1 = sampler .sample (
739+ ready_indexes ,
740+ batch_size ,
741+ data_replica_group = 0 ,
742+ data_replica_rank = 0 ,
743+ data_replica_world_size = 2 ,
744+ task_name = "task1" ,
745+ partition_id = "test" ,
746+ )
747+
748+ assert sampled_rank0_task1 == [0 , 1 ]
749+ assert consumed_rank0_task1 == [0 , 1 ]
750+ assert sampler ._states ["test" ]["task0" ][0 ][0 ] == []
751+ assert sampler ._states ["test" ]["task0" ][0 ][1 ] == [[0 , 1 ]]
752+ assert sampler ._states ["test" ]["task1" ][0 ][0 ] == []
753+ assert sampler ._states ["test" ]["task1" ][0 ][1 ] == [[0 , 1 ]]
754+
585755
586756class TestSamplerIntegration :
587757 """Integration tests for samplers."""
0 commit comments