Skip to content

Commit 4424e2d

Browse files
authored
[StreamingDataLoader, 1/N] feat: implement RankAwareSampler (#4)
## Background This PR is the first in a series [1/N] to introduce `StreamingDataLoader`, a mechanism designed to optimize data dispatch in distributed training. Specifically, this PR implements `RankAwareSampler`. In distributed data parallel (DP) scenarios where ranks retrieve data independently, this sampler ensures deterministic behavior: **it guarantees that all ranks within the same DP group receive identical sample indices**, synchronizing the data consumption process. Leveraging `StreamingDataLoader`, we can support micro-batch-level pipelining for training backends. By passing the dataloader instance directly into `forward_backward_func`, we avoid the bottleneck of retrieving full mini-batches in advance. This allows for highly efficient, fine-grained streaming throughout the training process. ```python3 data_iter = StreamingDataLoader() losses_reduced = self.forward_backward_func( forward_step_func=forward_step, data_iterator=data_iter, model=self.model, num_microbatches=num_microbatches, seq_length=self.seq_length, micro_batch_size=self.micro_batch_size, forward_only=forward_only, collect_non_loss_data=forward_only, ) ``` Please refer to our roadmap for more details: [[Roadmap] StreamingDataLoader for task-separated RL post-training](#1) <img width="1853" height="879" alt="image" src="https://github.com/user-attachments/assets/ecc891e2-dca2-407b-b194-fb4a7ddaf1cc" /> ## Note We have added `Co-authored-by` credits to the commit messages to properly attribute the work to the early developers from https://github.com/TransferQueue/TransferQueue.
2 parents b928421 + 3b6eff2 commit 4424e2d

5 files changed

Lines changed: 298 additions & 6 deletions

File tree

tests/test_samplers.py

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515

1616
"""Unit tests for TransferQueue samplers."""
1717

18+
import sys
19+
from pathlib import Path
1820
from typing import Any
1921

2022
import pytest
2123

22-
from transfer_queue.sampler import BaseSampler
23-
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler
24-
from transfer_queue.sampler.sequential_sampler import SequentialSampler
24+
# Setup path
25+
parent_dir = Path(__file__).resolve().parent.parent
26+
sys.path.append(str(parent_dir))
27+
28+
from transfer_queue.sampler import BaseSampler # noqa: E402
29+
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler # noqa: E402
30+
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler # noqa: E402
31+
from transfer_queue.sampler.sequential_sampler import SequentialSampler # noqa: E402
2532

2633

2734
class TestBaseSampler:
@@ -427,6 +434,155 @@ def test_grpo_sampler_insufficient_groups(self):
427434
assert consumed == []
428435

429436

437+
class TestRankAwareSampler:
    """Unit tests for the per-DP-group coordination done by RankAwareSampler."""

    def test_rank_aware_sampler_initialization(self):
        """A new sampler is a BaseSampler and starts with an empty state dict."""
        rank_sampler = RankAwareSampler()

        assert isinstance(rank_sampler, BaseSampler)
        assert hasattr(rank_sampler, "_states")
        assert rank_sampler._states == {}

    def test_rank_aware_sampler_first_rank_sampling(self):
        """The first caller of a DP group draws the batch from the ready pool."""
        rank_sampler = RankAwareSampler()
        available = list(range(6))
        wanted = 3

        # world_size == dp_world_size means a batch is fetched exactly once,
        # so the very first call is also the last one for that batch.
        picked, marked = rank_sampler.sample(available, wanted, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == [0, 1, 2]
        assert marked == [0, 1, 2]
        assert len(picked) == wanted
        # Nothing may remain cached once the only fetch has happened.
        assert rank_sampler._states == {}

    def test_rank_aware_sampler_second_rank_gets_cached(self):
        """A second caller of the same DP group receives the cached batch."""
        rank_sampler = RankAwareSampler()
        available = list(range(6))
        wanted = 3
        # world_size=4 with dp_world_size=2 gives two fetches per batch.
        topology = {"dp_group": 0, "dp_world_size": 2, "world_size": 4}

        first_picked, first_marked = rank_sampler.sample(available, wanted, **topology)
        second_picked, second_marked = rank_sampler.sample(available, wanted, **topology)

        assert first_picked == second_picked == [0, 1, 2]
        # Both calls report the same consumed indexes.
        assert first_marked == [0, 1, 2]
        assert second_marked == [0, 1, 2]
        # The second fetch completes the batch, so the cache is dropped.
        assert rank_sampler._states == {}

    def test_rank_aware_sampler_multiple_dp_groups(self):
        """Distinct DP groups keep separate caches and advance independently."""
        rank_sampler = RankAwareSampler()
        remaining = list(range(8))
        wanted = 2
        topology = {"dp_world_size": 4, "world_size": 8}

        # (dp_group issuing the call, batch it must receive), in call order:
        # each group samples, each group re-reads its cache, then group 0
        # moves on to a fresh batch and re-reads that one as well.
        scenario = [
            (0, [0, 1]),
            (1, [2, 3]),
            (0, [0, 1]),
            (1, [2, 3]),
            (0, [4, 5]),
            (0, [4, 5]),
        ]

        for group, expected in scenario:
            picked, marked = rank_sampler.sample(remaining, wanted, dp_group=group, **topology)
            # Mimic the consumption status update managed in TransferQueueController.
            remaining = [idx for idx in remaining if idx not in marked]

            assert picked == expected
            assert marked == expected

        # Every batch has been fetched twice, so no group has cached state left.
        assert rank_sampler._states == {}

    def test_rank_aware_sampler_empty_ready_indexes(self):
        """An empty ready pool yields two empty lists."""
        rank_sampler = RankAwareSampler()

        picked, marked = rank_sampler.sample([], 3, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert marked == []

    def test_rank_aware_sampler_batch_size_larger_than_ready(self):
        """Fewer ready samples than batch_size yields two empty lists."""
        rank_sampler = RankAwareSampler()
        available = [0, 1]
        wanted = 5

        picked, marked = rank_sampler.sample(available, wanted, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert marked == []

    def test_rank_aware_sampler_zero_batch_size(self):
        """A batch_size of zero yields two empty lists."""
        rank_sampler = RankAwareSampler()
        available = [0, 1, 2, 3]

        picked, marked = rank_sampler.sample(available, 0, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert marked == []
584+
585+
430586
class TestSamplerIntegration:
431587
"""Integration tests for samplers."""
432588

transfer_queue/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .metadata import BatchMeta
2525
from .sampler import BaseSampler
2626
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
27+
from .sampler.rank_aware_sampler import RankAwareSampler
2728
from .sampler.sequential_sampler import SequentialSampler
2829
from .storage import SimpleStorageUnit
2930
from .utils.utils import get_placement_group
@@ -41,6 +42,7 @@
4142
"BaseSampler",
4243
"GRPOGroupNSampler",
4344
"SequentialSampler",
45+
"RankAwareSampler",
4446
]
4547

4648
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

transfer_queue/sampler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .base import BaseSampler
1717
from .grpo_group_n_sampler import GRPOGroupNSampler
18+
from .rank_aware_sampler import RankAwareSampler
1819
from .sequential_sampler import SequentialSampler
1920

20-
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"]
21+
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]

transfer_queue/sampler/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ class BaseSampler(ABC):
3434
- **SequentialSampler**: Default sampler, selects samples sequentially without replacement
3535
- **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready.
3636
It assumes the N samples associated with the same prompt are stored contiguously
37-
- **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO)
37+
- **RankAwareSampler**: Rank-aware sampling for distributed training where each rank retrieves data independently.
38+
This sampler will guarantee ranks of the same DP group consume identical samples.
3839
3940
NOTE: Always return both sampled and consumed indexes (may be identical).
4041
"""
4142

4243
def __init__(self):
43-
self._states: dict[str, Any] = {}
44+
self._states: dict[Any, Any] = {}
4445

4546
@abstractmethod
4647
def sample(
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any
17+
18+
from transfer_queue.sampler import BaseSampler
19+
20+
21+
class RankAwareSampler(BaseSampler):
    """Rank-aware sampler for distributed training with TransferQueue.

    This sampler is designed for distributed data parallel training scenarios
    where each rank retrieves data independently. It guarantees that all ranks
    that call :meth:`sample` with the same ``dp_group`` receive the same sample
    indices for a given batch.

    The sampler keeps per-DP-group state in ``self._states`` (keyed by
    ``dp_group``) to coordinate the calls:

    - The first call for a DP group performs the actual sampling from
      ``ready_indexes`` and caches the result.
    - Subsequent calls for the same DP group return the cached indices.
    - Once the batch has been fetched ``world_size // dp_world_size`` times,
      the cached state for that DP group is removed.

    Please refer to our roadmap for more details:
    [Roadmap] StreamingDataLoader for task-separated RL post-training
    https://github.com/Ascend/TransferQueue/issues/1
    """

    def __init__(self):
        """Initialize the RankAwareSampler with empty per-DP-group state."""
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        dp_group: int,
        dp_world_size: int,
        world_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Sample indices for the current rank, coordinating with its DP group.

        The first call for a given ``dp_group`` takes the first ``batch_size``
        entries of ``ready_indexes`` and caches them. Later calls for the same
        ``dp_group`` ignore ``ready_indexes`` and return the cached indices,
        until the batch has been fetched ``world_size // dp_world_size`` times.

        Args:
            ready_indexes: List of global indices for which all required fields of the
                corresponding samples have been produced, and the samples are not labeled
                as consumed in the corresponding task.
            batch_size: Number of samples to select. If fewer than ``batch_size``
                ready samples are available, nothing is sampled and two empty
                lists are returned.
            dp_group: The group id of the current data parallel group. Used as
                the key of the cached state.
            dp_world_size: Number of ranks in the data parallelism group. Must be
                positive and divide ``world_size``.
            world_size: Total number of ranks across all parallelism dimensions.
                Together with ``dp_world_size`` it determines how many times a
                batch is fetched before its cached state is dropped.
            *args: Additional positional arguments (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            A tuple ``(sampled_indexes, consumed_indexes)``. Both lists hold the
            same ``batch_size`` global indices, or both are empty if samples are
            insufficient. The lists are fresh copies, so callers may mutate them
            without affecting what other ranks of the DP group will receive.

        Raises:
            RuntimeError: If ``dp_world_size`` is not positive, or if
                ``world_size`` is not divisible by ``dp_world_size``.
        """
        # Validate the topology up front so an invalid call never touches state.
        if dp_world_size <= 0:
            raise RuntimeError(f"dp_world_size ({dp_world_size}) must be a positive integer")
        if world_size % dp_world_size != 0:
            raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")

        # How many times a single batch is fetched before it is retired.
        fetches_per_batch = world_size // dp_world_size

        state = self._states.get(dp_group)
        if state is None:
            # First fetch for this DP group: take the head of the ready pool.
            # Slicing already yields a list independent of ``ready_indexes``.
            batch = ready_indexes[:batch_size]
            if len(batch) < batch_size:
                # Not enough ready samples yet; leave no state behind.
                return [], []

            state = {"index": batch, "fetch_count": 1}
            self._states[dp_group] = state
        else:
            # Repeat fetch: serve the cached batch and record the access.
            state["fetch_count"] += 1

        # Retire the cached batch once every expected fetch has happened.
        if state["fetch_count"] >= fetches_per_batch:
            del self._states[dp_group]

        cached = state["index"]
        # Hand out copies so a caller mutating its result cannot corrupt the
        # batch still cached for the remaining ranks of this DP group.
        return list(cached), list(cached)

0 commit comments

Comments
 (0)