Skip to content

Commit 45cfafe

Browse files
authored
[StreamingDataLoader, 2/N] feat: support async sampling and data pre-fetch in RankAwareSampler (#7)
## Background

In the initial implementation introduced in PR #4, `RankAwareSampler` allowed individual ranks to fetch `BatchMeta` from `TransferQueueController`, guaranteeing that all ranks within the same data replica group receive the same sample indices.

However, this implementation had two main limitations:

- It did not account for asynchronous calls arising from different tasks in a task-separated RL framework.
- It did not support data pre-fetching when integrated with the `StreamingDataLoader` interface.

## Solution

This PR enhances `RankAwareSampler` to support multi-task concurrency and data pre-fetching:

- **Task & Partition Awareness**: Introduced `task_name` and `partition_id` parameters to correctly identify the current task context and apply distinct caching logic for each task.
- **Pre-fetching Support**: Implemented a dynamic buffer for each rank under each task.

CC: @NINGBENZHE

---------

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent eb59ad0 commit 45cfafe

2 files changed

Lines changed: 279 additions & 87 deletions

File tree

tests/test_samplers.py

Lines changed: 209 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self):
445445
assert sampler._states == {}
446446

447447
def test_rank_aware_sampler_first_rank_sampling(self):
448-
"""Test that first rank in DP group performs actual sampling."""
448+
"""Test that first rank in data replica group performs actual sampling."""
449449
sampler = RankAwareSampler()
450450
ready_indexes = [0, 1, 2, 3, 4, 5]
451451
batch_size = 3
452452

453-
# When world_size == dp_world_size, fetches_per_batch = 1
454-
# First rank samples and immediately marks consumed (no other ranks to wait for)
455-
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
453+
# Rank 0 (first in group) samples and caches for all ranks
454+
# Since rank 1 will call next, state is kept until rank 1 fetches
455+
sampled, consumed = sampler.sample(
456+
ready_indexes,
457+
batch_size,
458+
data_replica_group=0,
459+
data_replica_rank=0,
460+
data_replica_world_size=2,
461+
task_name="task",
462+
partition_id="test",
463+
)
456464

457465
assert sampled == [0, 1, 2]
458-
# consumed is returned
459466
assert consumed == [0, 1, 2]
460467
assert len(sampled) == batch_size
461-
# State should be cleaned up
462-
assert sampler._states == {}
468+
# State is kept for other ranks to fetch
463469

464470
def test_rank_aware_sampler_second_rank_gets_cached(self):
465-
"""Test that second rank in DP group gets cached indices."""
471+
"""Test that second rank in data replica group gets cached indices."""
466472
sampler = RankAwareSampler()
467473
ready_indexes = [0, 1, 2, 3, 4, 5]
468474
batch_size = 3
469-
dp_world_size = 2
470-
world_size = 4 # Use world_size=4 so fetches_per_batch=2
471475

472-
# Rank 0 (dp_group=0) samples first
476+
# Rank 0 (first in group) samples first
473477
sampled1, consumed1 = sampler.sample(
474-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
478+
ready_indexes,
479+
batch_size,
480+
data_replica_group=0,
481+
data_replica_rank=0,
482+
data_replica_world_size=2,
483+
task_name="task",
484+
partition_id="test",
475485
)
476486

477-
# Rank 1 (dp_group=0) should get same cached indices
487+
# Rank 1 (second in group) should get same cached indices
478488
sampled2, consumed2 = sampler.sample(
479-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
489+
ready_indexes,
490+
batch_size,
491+
data_replica_group=0,
492+
data_replica_rank=1,
493+
data_replica_world_size=2,
494+
task_name="task",
495+
partition_id="test",
480496
)
481497

482498
assert sampled1 == sampled2 == [0, 1, 2]
483-
# First rank already returns consumed indexes
484499
assert consumed1 == [0, 1, 2]
485-
# Second rank also sees the same consumed indexes; state is then cleaned up
486500
assert consumed2 == [0, 1, 2]
487-
# State should be cleaned up
488-
assert sampler._states == {}
501+
502+
# cache should be empty after all ranks fetch
503+
assert len(sampler._states["test"]["task"][0][0]) == 0
504+
assert len(sampler._states["test"]["task"][0][1]) == 0
489505

490506
def test_rank_aware_sampler_multiple_dp_groups(self):
491-
"""Test that multiple DP groups work independently."""
507+
"""Test that multiple data replica groups work independently."""
492508
sampler = RankAwareSampler()
493509
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
494510
batch_size = 2
495-
dp_world_size = 4
496-
world_size = 8
511+
data_replica_world_size = 2 # Each group has 2 ranks
497512

498-
# DP group 0: rank 0 samples first
513+
# data replica group 0: rank 0 samples first
499514
sampled0_g0, consumed0_g0 = sampler.sample(
500-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
515+
ready_indexes,
516+
batch_size,
517+
data_replica_group=0,
518+
data_replica_rank=0,
519+
data_replica_world_size=data_replica_world_size,
520+
task_name="task",
521+
partition_id="test",
501522
)
502523
# mimic the consumption status update managed in TransferQueueController
503524
ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]
504525

505-
# DP group 1: rank 0 samples first
526+
# data replica group 1: rank 0 samples first
506527
sampled0_g1, consumed0_g1 = sampler.sample(
507-
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
528+
ready_indexes,
529+
batch_size,
530+
data_replica_group=1,
531+
data_replica_rank=0,
532+
data_replica_world_size=data_replica_world_size,
533+
task_name="task",
534+
partition_id="test",
508535
)
509536
ready_indexes = [i for i in ready_indexes if i not in consumed0_g1]
510537

@@ -514,47 +541,82 @@ def test_rank_aware_sampler_multiple_dp_groups(self):
514541
assert consumed0_g0 == [0, 1]
515542
assert consumed0_g1 == [2, 3]
516543

517-
# DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
544+
# data replica group 0: rank 1 fetches cached
518545
sampled1_g0, consumed1_g0 = sampler.sample(
519-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
546+
ready_indexes,
547+
batch_size,
548+
data_replica_group=0,
549+
data_replica_rank=1,
550+
data_replica_world_size=data_replica_world_size,
551+
task_name="task",
552+
partition_id="test",
520553
)
521554
ready_indexes = [i for i in ready_indexes if i not in consumed1_g0]
522555
assert sampled1_g0 == [0, 1]
523556
assert consumed1_g0 == [0, 1]
524557

525-
# DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
558+
# data replica group 1: rank 1 fetches cached
526559
sampled1_g1, consumed1_g1 = sampler.sample(
527-
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
560+
ready_indexes,
561+
batch_size,
562+
data_replica_group=1,
563+
data_replica_rank=1,
564+
data_replica_world_size=data_replica_world_size,
565+
task_name="task",
566+
partition_id="test",
528567
)
529568
ready_indexes = [i for i in ready_indexes if i not in consumed1_g1]
530569
assert sampled1_g1 == [2, 3]
531570
assert consumed1_g1 == [2, 3]
532571

533-
# DP group 0: rank 0 fetches again, this should return new data
572+
# data replica group 0: rank 0 fetches again, this should return new data
534573
sampled2_g0, consumed2_g0 = sampler.sample(
535-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
574+
ready_indexes,
575+
batch_size,
576+
data_replica_group=0,
577+
data_replica_rank=0,
578+
data_replica_world_size=data_replica_world_size,
579+
task_name="task",
580+
partition_id="test",
536581
)
537582
ready_indexes = [i for i in ready_indexes if i not in consumed2_g0]
538583
assert sampled2_g0 == [4, 5]
539584
assert consumed2_g0 == [4, 5]
540585

541-
# DP group 0: rank 1 fetches cached
586+
# data replica group 0: rank 1 fetches cached
542587
sampled3_g0, consumed3_g0 = sampler.sample(
543-
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
588+
ready_indexes,
589+
batch_size,
590+
data_replica_group=0,
591+
data_replica_rank=1,
592+
data_replica_world_size=data_replica_world_size,
593+
task_name="task",
594+
partition_id="test",
544595
)
545596
assert sampled3_g0 == [4, 5]
546597
assert consumed3_g0 == [4, 5]
547598

548-
# Both groups should be cleaned up
549-
assert sampler._states == {}
599+
# examine the internal state to ensure proper caching and clearing
600+
assert len(sampler._states["test"]["task"][0][0]) == 0
601+
assert len(sampler._states["test"]["task"][0][1]) == 0
602+
assert len(sampler._states["test"]["task"][1][0]) == 0
603+
assert len(sampler._states["test"]["task"][1][1]) == 0
550604

551605
def test_rank_aware_sampler_empty_ready_indexes(self):
552606
"""Test behavior with empty ready indexes."""
553607
sampler = RankAwareSampler()
554608
ready_indexes = []
555609
batch_size = 3
556610

557-
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
611+
sampled, consumed = sampler.sample(
612+
ready_indexes,
613+
batch_size,
614+
data_replica_group=0,
615+
data_replica_rank=0,
616+
data_replica_world_size=2,
617+
task_name="task",
618+
partition_id="test",
619+
)
558620

559621
assert sampled == []
560622
assert consumed == []
@@ -565,8 +627,15 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self):
565627
ready_indexes = [0, 1]
566628
batch_size = 5
567629

568-
# When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately
569-
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
630+
sampled, consumed = sampler.sample(
631+
ready_indexes,
632+
batch_size,
633+
data_replica_group=0,
634+
data_replica_rank=0,
635+
data_replica_world_size=2,
636+
task_name="task",
637+
partition_id="test",
638+
)
570639

571640
assert sampled == []
572641
assert consumed == []
@@ -577,11 +646,112 @@ def test_rank_aware_sampler_zero_batch_size(self):
577646
ready_indexes = [0, 1, 2, 3]
578647
batch_size = 0
579648

580-
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
649+
sampled, consumed = sampler.sample(
650+
ready_indexes,
651+
batch_size,
652+
data_replica_group=0,
653+
data_replica_rank=0,
654+
data_replica_world_size=2,
655+
task_name="task",
656+
partition_id="test",
657+
)
581658

582659
assert sampled == []
583660
assert consumed == []
584661

662+
def test_rank_aware_sampler_data_prefetch(self):
663+
"""Test behavior with data prefetch."""
664+
sampler = RankAwareSampler()
665+
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
666+
batch_size = 2
667+
668+
sampled_rank0_time0, consumed_rank0_time0 = sampler.sample(
669+
ready_indexes,
670+
batch_size,
671+
data_replica_group=0,
672+
data_replica_rank=0,
673+
data_replica_world_size=2,
674+
task_name="task",
675+
partition_id="test",
676+
)
677+
678+
assert sampled_rank0_time0 == [0, 1]
679+
assert consumed_rank0_time0 == [0, 1]
680+
assert sampler._states["test"]["task"][0][0] == []
681+
assert sampler._states["test"]["task"][0][1] == [[0, 1]]
682+
683+
ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0]
684+
685+
sampled_rank0_time1, consumed_rank0_time1 = sampler.sample(
686+
ready_indexes,
687+
batch_size,
688+
data_replica_group=0,
689+
data_replica_rank=0,
690+
data_replica_world_size=2,
691+
task_name="task",
692+
partition_id="test",
693+
)
694+
695+
assert sampled_rank0_time1 == [2, 3]
696+
assert consumed_rank0_time1 == [2, 3]
697+
assert sampler._states["test"]["task"][0][0] == []
698+
assert sampler._states["test"]["task"][0][1] == [[0, 1], [2, 3]]
699+
700+
ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1]
701+
702+
sampled_rank1_time0, consumed_rank1_time0 = sampler.sample(
703+
ready_indexes,
704+
batch_size,
705+
data_replica_group=0,
706+
data_replica_rank=1,
707+
data_replica_world_size=2,
708+
task_name="task",
709+
partition_id="test",
710+
)
711+
assert sampled_rank1_time0 == [0, 1]
712+
assert consumed_rank1_time0 == [0, 1]
713+
714+
assert sampler._states["test"]["task"][0][0] == []
715+
assert sampler._states["test"]["task"][0][1] == [[2, 3]]
716+
717+
def test_rank_aware_sampler_multiple_tasks(self):
718+
"""Test behavior with multiple tasks."""
719+
sampler = RankAwareSampler()
720+
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
721+
batch_size = 2
722+
723+
sampled_rank0_task0, consumed_rank0_task0 = sampler.sample(
724+
ready_indexes,
725+
batch_size,
726+
data_replica_group=0,
727+
data_replica_rank=0,
728+
data_replica_world_size=2,
729+
task_name="task0",
730+
partition_id="test",
731+
)
732+
733+
assert sampled_rank0_task0 == [0, 1]
734+
assert consumed_rank0_task0 == [0, 1]
735+
assert sampler._states["test"]["task0"][0][0] == []
736+
assert sampler._states["test"]["task0"][0][1] == [[0, 1]]
737+
738+
sampled_rank0_task1, consumed_rank0_task1 = sampler.sample(
739+
ready_indexes,
740+
batch_size,
741+
data_replica_group=0,
742+
data_replica_rank=0,
743+
data_replica_world_size=2,
744+
task_name="task1",
745+
partition_id="test",
746+
)
747+
748+
assert sampled_rank0_task1 == [0, 1]
749+
assert consumed_rank0_task1 == [0, 1]
750+
assert sampler._states["test"]["task0"][0][0] == []
751+
assert sampler._states["test"]["task0"][0][1] == [[0, 1]]
752+
assert sampler._states["test"]["task1"][0][0] == []
753+
assert sampler._states["test"]["task1"][0][1] == [[0, 1]]
754+
585755

586756
class TestSamplerIntegration:
587757
"""Integration tests for samplers."""

0 commit comments

Comments
 (0)